[util] Add generic torch device class #6174

Merged · 19 commits · Apr 15, 2024
This diff shows the changes from 4 of the 19 commits.
4 changes: 2 additions & 2 deletions invokeai/app/api_app.py
@@ -28,7 +28,7 @@
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.services.session_processor.session_processor_common import ProgressImage
-from invokeai.backend.util.devices import get_torch_device_name
+from invokeai.backend.util.devices import TorchDeviceSelect

 from ..backend.util.logging import InvokeAILogger
 from .api.dependencies import ApiDependencies
@@ -63,7 +63,7 @@
 mimetypes.add_type("application/javascript", ".js")
 mimetypes.add_type("text/css", ".css")

-torch_device_name = get_torch_device_name()
+torch_device_name = TorchDeviceSelect().get_torch_device_name()
 logger.info(f"Using torch device: {torch_device_name}")
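None of these four commits shows invokeai/backend/util/devices.py itself, so the new class has to be read off its call sites. A minimal sketch of the interface they imply; the method names come from the diff, but the selection logic and fallback order here are assumptions, not the merged implementation:

from typing import Optional

import torch


class TorchDeviceSelect:
    """Pick the torch device and dtype for this process (sketch from call sites)."""

    def __init__(self, context=None) -> None:
        # Call sites pass an optional InvocationContext; presumably it lets the
        # model manager pick the device (context-based selection omitted here).
        self._context = context

    def choose_torch_device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def choose_torch_dtype(self, device: Optional[torch.device] = None) -> torch.dtype:
        device = device or self.choose_torch_device()
        # Half precision on accelerators, float32 on CPU: a common default.
        return torch.float16 if device.type in ("cuda", "mps") else torch.float32

    def get_torch_device_name(self) -> str:
        return str(self.choose_torch_device())

    @classmethod
    def empty_cache(cls) -> None:
        # One backend-agnostic replacement for the cuda/mps branches seen below.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()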
6 changes: 3 additions & 3 deletions invokeai/app/invocations/compel.py
@@ -17,7 +17,7 @@
     ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
-from invokeai.backend.util.devices import torch_dtype
+from invokeai.backend.util.devices import TorchDeviceSelect

 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import CLIPField
@@ -89,7 +89,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
                 textual_inversion_manager=ti_manager,
-                dtype_for_device_getter=torch_dtype,
+                dtype_for_device_getter=TorchDeviceSelect(context).choose_torch_dtype,
                 truncate_long_prompts=False,
             )

@@ -191,7 +191,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
                 textual_inversion_manager=ti_manager,
-                dtype_for_device_getter=torch_dtype,
+                dtype_for_device_getter=TorchDeviceSelect(context).choose_torch_dtype,
                 truncate_long_prompts=False,  # TODO:
                 returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
                 requires_pooled=get_pooled,
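Both hunks swap a module-level function for a bound method passed uncalled: compel receives a callable and asks for the dtype lazily, per device. A small illustration (here context stands in for the invocation's InvocationContext, as in the hunks above):

import torch

from invokeai.backend.util.devices import TorchDeviceSelect

selector = TorchDeviceSelect(context)     # context: the active InvocationContext
get_dtype = selector.choose_torch_dtype   # a callable, not a dtype value
assert callable(get_dtype)
dtype = get_dtype(torch.device("cuda"))   # e.g. torch.float16 on a CUDA device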
42 changes: 13 additions & 29 deletions invokeai/app/invocations/latent.py
@@ -63,15 +63,12 @@
     image_resized_to_grid_as_tensor,
 )
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import choose_precision, choose_torch_device
+from ...backend.util.devices import TorchDeviceSelect
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .controlnet_image_processors import ControlField
 from .model import ModelIdentifierField, UNetField, VAEField

-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-
-DEFAULT_PRECISION = choose_precision(choose_torch_device())
+DEFAULT_PRECISION = TorchDeviceSelect().choose_torch_dtype


 @invocation_output("scheduler_output")
@@ -794,9 +791,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDeviceSelect.empty_cache()

         name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
@@ -863,9 +858,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         vae.disable_tiling()

         # clear memory as vae decode can request a lot
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDeviceSelect.empty_cache()

         with torch.inference_mode():
             # copied from diffusers pipeline
@@ -877,9 +870,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

         image = VaeImageProcessor.numpy_to_pil(np_image)[0]

-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDeviceSelect.empty_cache()

         image_dto = context.images.save(image=image)

@@ -919,8 +910,7 @@ class ResizeLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.tensors.load(self.latents.latents_name)

-        # TODO:
-        device = choose_torch_device()
+        device = TorchDeviceSelect(context).choose_torch_device()

         resized_latents = torch.nn.functional.interpolate(
             latents.to(device),
@@ -931,9 +921,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+
+        TorchDeviceSelect.empty_cache()

         name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -960,8 +949,7 @@ class ScaleLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.tensors.load(self.latents.latents_name)

-        # TODO:
-        device = choose_torch_device()
+        device = TorchDeviceSelect(context).choose_torch_device()

         # resizing
         resized_latents = torch.nn.functional.interpolate(
@@ -973,9 +961,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+        TorchDeviceSelect.empty_cache()

         name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1107,8 +1093,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         if latents_a.shape != latents_b.shape:
             raise Exception("Latents to blend must be the same size.")

-        # TODO:
-        device = choose_torch_device()
+        device = TorchDeviceSelect(context).choose_torch_device()

         def slerp(
             t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?
@@ -1161,9 +1146,8 @@ def slerp(

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         blended_latents = blended_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+
+        TorchDeviceSelect.empty_cache()

         name = context.tensors.save(tensor=blended_latents)
         return LatentsOutput.build(latents_name=name, latents=blended_latents)
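Every hunk above collapses the same three-line, backend-specific ritual into a single call made on the class itself (no instance is created for cache clearing). A sketch of the equivalence, assuming empty_cache dispatches on whichever backend is active:

import torch

from invokeai.backend.util.devices import TorchDeviceSelect

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Before: each call site branched on the backend by hand.
torch.cuda.empty_cache()  # safe no-op when CUDA was never initialized
if device == torch.device("mps"):
    from torch import mps  # the old module-level guarded import, inlined here
    mps.empty_cache()

# After: one backend-agnostic call.
TorchDeviceSelect.empty_cache()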
6 changes: 3 additions & 3 deletions invokeai/app/invocations/noise.py
@@ -9,7 +9,7 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX

-from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.util.devices import TorchDeviceSelect
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -46,7 +46,7 @@ def get_noise(
            height // downsampling_factor,
            width // downsampling_factor,
        ],
-        dtype=torch_dtype(device),
+        dtype=TorchDeviceSelect().choose_torch_dtype(device),
        device=noise_device_type,
        generator=generator,
    ).to("cpu")
@@ -118,7 +118,7 @@ def invoke(self, context: InvocationContext) -> NoiseOutput:
         noise = get_noise(
             width=self.width,
             height=self.height,
-            device=choose_torch_device(),
+            device=TorchDeviceSelect(context).choose_torch_device(),
             seed=self.seed,
             use_cpu=self.use_cpu,
         )
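The two get_noise hunks keep a deliberate split: the dtype tracks the target device, while the tensor is generated with a seeded generator and parked on the CPU so results reproduce across machines. A condensed, hedged sketch (shape and seed are illustrative):

import torch

from invokeai.backend.util.devices import TorchDeviceSelect

selector = TorchDeviceSelect()
device = selector.choose_torch_device()
dtype = selector.choose_torch_dtype(device)   # dtype keyed to the target device

generator = torch.Generator(device="cpu").manual_seed(42)
noise = torch.randn(
    (1, 4, 64, 64),   # illustrative latent shape
    dtype=dtype,
    device="cpu",     # generate where the generator lives
    generator=generator,
).to("cpu")           # keep noise on the CPU, as get_noise does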
10 changes: 2 additions & 8 deletions invokeai/app/invocations/upscale.py
@@ -4,7 +4,6 @@

 import cv2
 import numpy as np
-import torch
 from PIL import Image
 from pydantic import ConfigDict

@@ -14,7 +13,7 @@
 from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
 from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDeviceSelect

 from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, WithBoard, WithMetadata
@@ -35,9 +34,6 @@
     "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
 }

-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-

 @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
 class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -120,9 +116,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         upscaled_image = upscaler.upscale(cv2_image)
         pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")

-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDeviceSelect.empty_cache()

         image_dto = context.images.save(image=pil_image)

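Dropping the module-level guard (if choose_torch_device() == torch.device("mps"): from torch import mps) does more than delete three lines: it removes a device query that ran as an import-time side effect. Importing this module no longer touches the GPU stack at all; the backend check now happens inside TorchDeviceSelect.empty_cache(), at call time.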
10 changes: 5 additions & 5 deletions invokeai/app/services/model_install/model_install_default.py
@@ -13,6 +13,7 @@
 from tempfile import mkdtemp
 from typing import Any, Dict, List, Optional, Union

+import torch
 import yaml
 from huggingface_hub import HfFolder
 from pydantic.networks import AnyHttpUrl
@@ -42,7 +43,7 @@
 from invokeai.backend.model_manager.probe import ModelProbe
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util import InvokeAILogger
-from invokeai.backend.util.devices import choose_precision, choose_torch_device
+from invokeai.backend.util.devices import TorchDeviceSelect

 from .model_install_base import (
     MODEL_SOURCE_TO_TYPE_MAP,
@@ -634,11 +635,10 @@ def _next_id(self) -> int:
         self._next_job_id += 1
         return id

-    @staticmethod
-    def _guess_variant() -> Optional[ModelRepoVariant]:
+    def _guess_variant(self) -> Optional[ModelRepoVariant]:
         """Guess the best HuggingFace variant type to download."""
-        precision = choose_precision(choose_torch_device())
-        return ModelRepoVariant.FP16 if precision == "float16" else None
+        precision = TorchDeviceSelect().choose_torch_dtype
+        return ModelRepoVariant.FP16 if precision == torch.float16 else None

     def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
         return ModelInstallJob(
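One caveat at this commit: choose_torch_dtype is referenced without parentheses, so precision holds a bound method and the comparison with torch.float16 can never be true (presumably what the resolved review thread was about). The apparent intent, sketched with the file's own imports:

precision = TorchDeviceSelect().choose_torch_dtype()  # called, yielding a torch.dtype
variant = ModelRepoVariant.FP16 if precision == torch.float16 else None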
@@ -1,12 +1,14 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

+from typing import Optional
+
 import torch
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDeviceSelect
 from invokeai.backend.util.logging import InvokeAILogger

 from ..config import InvokeAIAppConfig
@@ -67,7 +69,7 @@ def build_model_manager(
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device = choose_torch_device(),
+        execution_device: Optional[torch.device] = None,
     ) -> Self:
         """
         Construct the model manager service instance.
@@ -82,7 +84,7 @@ def build_model_manager(
             max_vram_cache_size=app_config.vram,
             lazy_offloading=app_config.lazy_offload,
             logger=logger,
-            execution_device=execution_device,
+            execution_device=execution_device or TorchDeviceSelect().choose_torch_device(),
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(
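The signature change is more than tidying: a default such as execution_device=choose_torch_device() is evaluated once, at import time, which freezes the device before any configuration is read. Defaulting to None defers the decision to call time. A minimal sketch of the pattern (the function name is illustrative):

from typing import Optional

import torch

from invokeai.backend.util.devices import TorchDeviceSelect


def resolve_execution_device(execution_device: Optional[torch.device] = None) -> torch.device:
    # None means "decide now", at call time, not once at import time.
    return execution_device or TorchDeviceSelect().choose_torch_device()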
15 changes: 14 additions & 1 deletion invokeai/app/services/shared/invocation_context.py
@@ -4,7 +4,7 @@
 from typing import TYPE_CHECKING, Optional, Union

 from PIL.Image import Image
-from torch import Tensor
+from torch import Tensor, torch

 from invokeai.app.invocations.constants import IMAGE_MODES
 from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
@@ -426,6 +426,19 @@ def search_by_attrs(
             model_format=format,
         )

+    def get_free_device(self) -> torch.device:
+        """Return a free GPU for accelerated torch operations.
+        Args:
+            none
+
+        Returns:
+            A torch.device object.
+
+        Will raise a NotImplementedError until the multi-GPU support
+        PR is merged.
+        """
+        return self._services.model_manager.load.ram_cache.get_execution_device()
+

 class ConfigInterface(InvocationContextInterface):
     def get(self) -> InvokeAIAppConfig:
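A hedged sketch of how a caller might use the new accessor; the helper name is made up, and the context.models call path is inferred from where the method lands in this file (between search_by_attrs and ConfigInterface), not confirmed by the diff:

import torch

from invokeai.app.services.shared.invocation_context import InvocationContext


def move_to_free_gpu(latents: torch.Tensor, context: InvocationContext) -> torch.Tensor:
    # Ask the model manager's RAM cache for a free execution device.
    device = context.models.get_free_device()  # assumed accessor path
    return latents.to(device)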
8 changes: 4 additions & 4 deletions invokeai/backend/image_util/depth_anything/__init__.py
@@ -13,7 +13,7 @@
 from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDeviceSelect
 from invokeai.backend.util.logging import InvokeAILogger

 config = get_config()
@@ -56,7 +56,7 @@ class DepthAnythingDetector:
     def __init__(self) -> None:
         self.model = None
         self.model_size: Union[Literal["large", "base", "small"], None] = None
-        self.device = choose_torch_device()
+        self.device = TorchDeviceSelect().choose_torch_device()

     def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
         DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
@@ -81,7 +81,7 @@ def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
         self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
         self.model.eval()

-        self.model.to(choose_torch_device())
+        self.model.to(self.device)
         return self.model

     def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
@@ -94,7 +94,7 @@ def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:

         image_height, image_width = np_image.shape[:2]
         np_image = transform({"image": np_image})["image"]
-        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
+        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)

         with torch.no_grad():
             depth = self.model(tensor_image)
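Beyond the import swap, these hunks move the device decision to a single point: it is made once in __init__ and the cached self.device is reused in both load_model and __call__, so the weights and the input tensors are guaranteed to land on the same device instead of re-querying the selector at each site.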
6 changes: 3 additions & 3 deletions invokeai/backend/image_util/dw_openpose/wholebody.py
@@ -7,7 +7,7 @@

 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDeviceSelect

 from .onnxdet import inference_detector
 from .onnxpose import inference_pose
@@ -28,9 +28,9 @@

 class Wholebody:
     def __init__(self):
-        device = choose_torch_device()
+        device = TorchDeviceSelect().choose_torch_device()

-        providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
+        providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]

         DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
         download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
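The second change here is a genuine bug fix, not just a rename: a torch.device compares unequal to a plain string (at least on the torch versions this codebase targets), so the old device == "cuda" check always fell through to the CPU execution provider. Comparing device.type restores the intent. A quick check that runs without a GPU:

import torch

device = torch.device("cuda")  # constructing the object needs no actual GPU
print(device == "cuda")        # False: a torch.device never equals a str
print(device.type == "cuda")   # True: .type is the plain string "cuda"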
4 changes: 2 additions & 2 deletions invokeai/backend/image_util/infill_methods/lama.py
@@ -8,7 +8,7 @@
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDeviceSelect


 def norm_img(np_img):
@@ -29,7 +29,7 @@ def load_jit_model(url_or_path, device):

 class LaMA:
     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-        device = choose_torch_device()
+        device = TorchDeviceSelect().choose_torch_device()
         model_location = get_config().models_path / "core/misc/lama/lama.pt"

         if not model_location.exists():