[util] Add generic torch device class #6174

Merged: 19 commits, Apr 15, 2024
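For orientation, below is a minimal sketch of the TorchDevice helper this PR introduces, reconstructed from its call sites in the diff (get_torch_device_name, choose_torch_device, choose_torch_dtype, empty_cache). The method names come from the diff; the selection logic shown here is an assumption, not the actual implementation in invokeai/backend/util/devices.py.

import torch
from typing import Optional

class TorchDevice:
    """Sketch only: centralizes device/dtype selection and cache clearing."""

    @classmethod
    def choose_torch_device(cls) -> torch.device:
        # Assumed preference order: CUDA, then MPS, then CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @classmethod
    def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
        # Assumed policy: half precision on accelerators, full precision on CPU.
        device = device or cls.choose_torch_device()
        return torch.float16 if device.type in ("cuda", "mps") else torch.float32

    @classmethod
    def get_torch_device_name(cls) -> str:
        device = cls.choose_torch_device()
        return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()

    @classmethod
    def empty_cache(cls) -> None:
        # Replaces the scattered torch.cuda / torch.mps cache-clearing branches removed below.
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()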
4 changes: 2 additions & 2 deletions invokeai/app/api_app.py
@@ -28,7 +28,7 @@
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
-from invokeai.backend.util.devices import get_torch_device_name
+from invokeai.backend.util.devices import TorchDevice

from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@@ -63,7 +63,7 @@
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")

-torch_device_name = get_torch_device_name()
+torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")


6 changes: 3 additions & 3 deletions invokeai/app/invocations/compel.py
@@ -24,7 +24,7 @@
ConditioningFieldData,
SDXLConditioningInfo,
)
-from invokeai.backend.util.devices import torch_dtype
+from invokeai.backend.util.devices import TorchDevice

from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import CLIPField
@@ -99,7 +99,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
-dtype_for_device_getter=torch_dtype,
+dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
)

@@ -193,7 +193,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
-dtype_for_device_getter=torch_dtype,
+dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
43 changes: 13 additions & 30 deletions invokeai/app/invocations/latent.py
@@ -72,15 +72,12 @@
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import choose_precision, choose_torch_device
+from ...backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField

-if choose_torch_device() == torch.device("mps"):
-from torch import mps

-DEFAULT_PRECISION = choose_precision(choose_torch_device())
+DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()


@invocation_output("scheduler_output")
@@ -959,9 +956,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
-torch.cuda.empty_cache()
-if choose_torch_device() == torch.device("mps"):
-mps.empty_cache()
+TorchDevice.empty_cache()

name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
@@ -1028,9 +1023,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
vae.disable_tiling()

# clear memory as vae decode can request a lot
-torch.cuda.empty_cache()
-if choose_torch_device() == torch.device("mps"):
-mps.empty_cache()
+TorchDevice.empty_cache()

with torch.inference_mode():
# copied from diffusers pipeline
@@ -1042,9 +1035,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

image = VaeImageProcessor.numpy_to_pil(np_image)[0]

-torch.cuda.empty_cache()
-if choose_torch_device() == torch.device("mps"):
-mps.empty_cache()
+TorchDevice.empty_cache()

image_dto = context.images.save(image=image)

@@ -1083,9 +1074,7 @@ class ResizeLatentsInvocation(BaseInvocation):

def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)

-# TODO:
-device = choose_torch_device()
+device = TorchDevice.choose_torch_device()

resized_latents = torch.nn.functional.interpolate(
latents.to(device),
@@ -1096,9 +1085,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
-torch.cuda.empty_cache()
-if device == torch.device("mps"):
-mps.empty_cache()

+TorchDevice.empty_cache()

name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1125,8 +1113,7 @@ class ScaleLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)

-# TODO:
-device = choose_torch_device()
+device = TorchDevice.choose_torch_device()

# resizing
resized_latents = torch.nn.functional.interpolate(
@@ -1138,9 +1125,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
-torch.cuda.empty_cache()
-if device == torch.device("mps"):
-mps.empty_cache()
+TorchDevice.empty_cache()

name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1272,8 +1257,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")

-# TODO:
-device = choose_torch_device()
+device = TorchDevice.choose_torch_device()

def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
@@ -1326,9 +1310,8 @@ def slerp(

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
-torch.cuda.empty_cache()
-if device == torch.device("mps"):
-mps.empty_cache()

+TorchDevice.empty_cache()

name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents)
8 changes: 4 additions & 4 deletions invokeai/app/invocations/noise.py
@@ -9,7 +9,7 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX

-from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.util.devices import TorchDevice
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@@ -46,7 +46,7 @@ def get_noise(
height // downsampling_factor,
width // downsampling_factor,
],
-dtype=torch_dtype(device),
+dtype=TorchDevice.choose_torch_dtype(device=device),
device=noise_device_type,
generator=generator,
).to("cpu")
@@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):

@field_validator("seed", mode="before")
def modulo_seed(cls, v):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)

def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise(
width=self.width,
height=self.height,
-device=choose_torch_device(),
+device=TorchDevice.choose_torch_device(),
seed=self.seed,
use_cpu=self.use_cpu,
)
10 changes: 2 additions & 8 deletions invokeai/app/invocations/upscale.py
@@ -4,7 +4,6 @@

import cv2
import numpy as np
-import torch
from PIL import Image
from pydantic import ConfigDict

@@ -14,7 +13,7 @@
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice

from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
@@ -35,9 +34,6 @@
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
}

-if choose_torch_device() == torch.device("mps"):
-from torch import mps


@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -120,9 +116,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")

-torch.cuda.empty_cache()
-if choose_torch_device() == torch.device("mps"):
-mps.empty_cache()
+TorchDevice.empty_cache()

image_dto = context.images.save(image=pil_image)

57 changes: 43 additions & 14 deletions invokeai/app/services/config/config_default.py
@@ -27,12 +27,12 @@
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
-PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
+PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.0"
+CONFIG_SCHEMA_VERSION = "4.0.1"


def get_default_ram_cache_size() -> float:
@@ -105,7 +105,7 @@ class InvokeAIAppConfig(BaseSettings):
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
-precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
+precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
@@ -370,6 +370,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
+# autocast was removed in v4.0.1
+if k == "precision" and v == "autocast":
+parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
@@ -392,6 +395,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
return config


+def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
+"""Migrate v4.0.0 config dictionary to a current config object.

+Args:
+config_dict: A dictionary of settings from a v4.0.0 config file.

+Returns:
+An instance of `InvokeAIAppConfig` with the migrated settings.
+"""
+parsed_config_dict: dict[str, Any] = {}
+for k, v in config_dict.items():
+# autocast was removed from precision in v4.0.1
+if k == "precision" and v == "autocast":
+parsed_config_dict["precision"] = "auto"
+else:
+parsed_config_dict[k] = v
+if k == "schema_version":
+parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
+config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
+return config


def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.

@@ -418,17 +443,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
-else:
-# Attempt to load as a v4 config file
-try:
-# Meta is not included in the model fields, so we need to validate it separately
-config = InvokeAIAppConfig.model_validate(loaded_config_dict)
-assert (
-config.schema_version == CONFIG_SCHEMA_VERSION
-), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
-return config
-except Exception as e:
-raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e

+if loaded_config_dict["schema_version"] == "4.0.0":
+loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
+loaded_config_dict.write_file(config_path)

+# Attempt to load as a v4 config file
+try:
+# Meta is not included in the model fields, so we need to validate it separately
+config = InvokeAIAppConfig.model_validate(loaded_config_dict)
+assert (
+config.schema_version == CONFIG_SCHEMA_VERSION
+), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
+return config
+except Exception as e:
+raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e


@lru_cache(maxsize=1)
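A hypothetical usage sketch of the new v4.0.0 -> v4.0.1 migration path (the literal values below are chosen for illustration only): a config that still carries the removed "autocast" precision is coerced to "auto" and re-stamped with the current schema version.

old_config = {"schema_version": "4.0.0", "precision": "autocast", "ram": 7.5}
migrated = migrate_v4_0_0_config_dict(old_config)
assert migrated.precision == "auto"
assert migrated.schema_version == CONFIG_SCHEMA_VERSION  # "4.0.1"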
10 changes: 5 additions & 5 deletions invokeai/app/services/model_install/model_install_default.py
@@ -13,6 +13,7 @@
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union

+import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
@@ -42,7 +43,7 @@
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
-from invokeai.backend.util.devices import choose_precision, choose_torch_device
+from invokeai.backend.util.devices import TorchDevice

from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
@@ -634,11 +635,10 @@ def _next_id(self) -> int:
self._next_job_id += 1
return id

-@staticmethod
-def _guess_variant() -> Optional[ModelRepoVariant]:
+def _guess_variant(self) -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download."""
-precision = choose_precision(choose_torch_device())
-return ModelRepoVariant.FP16 if precision == "float16" else None
+precision = TorchDevice.choose_torch_dtype()
+return ModelRepoVariant.FP16 if precision == torch.float16 else None

def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
return ModelInstallJob(
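Note the comparison change above: choose_torch_dtype() is assumed to return a torch.dtype rather than the precision string produced by the removed choose_precision(), so the FP16 check now compares against torch.float16. A minimal sketch of the distinction:

import torch

precision = torch.float16  # stand-in for TorchDevice.choose_torch_dtype() on a CUDA/MPS machine
is_fp16 = precision == torch.float16  # dtype comparison replaces the old string check against "float16"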
@@ -1,12 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""

+from typing import Optional

import torch
from typing_extensions import Self

from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

from ..config import InvokeAIAppConfig
@@ -67,7 +69,7 @@ def build_model_manager(
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
-execution_device: torch.device = choose_torch_device(),
+execution_device: Optional[torch.device] = None,
) -> Self:
"""
Construct the model manager service instance.
@@ -82,7 +84,7 @@ def build_model_manager(
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
-execution_device=execution_device,
+execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(