[util] Add generic torch device class (#6174)
* introduce new abstraction layer for GPU devices

* add unit test for device abstraction

* fix ruff

* convert TorchDeviceSelect into a stateless class

* move logic to select context-specific execution device into context API

* add mock hardware environments to pytest

* remove dangling mocker fixture

* fix unit test for running on non-CUDA systems

* remove unimplemented get_execution_device() call

* remove autocast precision

* Multiple changes:

1. Remove TorchDeviceSelect.get_execution_device(), as well as calls to
   context.models.get_execution_device().
2. Rename TorchDeviceSelect to TorchDevice.
3. Add back the legacy public API defined in `invocation_api`, including
   choose_precision().
4. Add a config file migration script to accommodate removal of precision=autocast.

* add deprecation warnings to choose_torch_device() and choose_precision()

* fix test crash

* remove app_config argument from choose_torch_device() and choose_torch_dtype()

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
lstein and Lincoln Stein authored Apr 15, 2024
1 parent 5a8489b commit e93f4d6
Showing 20 changed files with 331 additions and 180 deletions.
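
The diff exercises the new TorchDevice class at several call sites: TorchDevice.choose_torch_device(), TorchDevice.choose_torch_dtype(), TorchDevice.get_torch_device_name(), and TorchDevice.empty_cache(). The class itself lives in invokeai/backend/util/devices.py, which is not part of this excerpt; the sketch below is inferred from those call sites only and is not the actual implementation.

from typing import Optional

import torch


class TorchDevice:
    """Stateless helpers for choosing the torch execution device and dtype (sketch only)."""

    @classmethod
    def choose_torch_device(cls) -> torch.device:
        """Return the preferred execution device: CUDA, then MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @classmethod
    def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
        """Return a dtype suited to the device: half precision on GPU, float32 on CPU."""
        device = device or cls.choose_torch_device()
        return torch.float16 if device.type in ("cuda", "mps") else torch.float32

    @classmethod
    def get_torch_device_name(cls) -> str:
        """Return a human-readable name for the selected device."""
        device = cls.choose_torch_device()
        return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()

    @classmethod
    def empty_cache(cls) -> None:
        """Release cached allocator memory on CUDA or MPS; a no-op on CPU."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()

Per the commit message, the legacy helpers choose_torch_device() and choose_precision() remain available through invocation_api but now emit deprecation warnings. A hypothetical shim (wording illustrative only, not taken from the commit) would look like:

import warnings


def choose_torch_device() -> torch.device:
    # Deprecated wrapper; delegates to the new class-based API.
    warnings.warn(
        "choose_torch_device() is deprecated; use TorchDevice.choose_torch_device() instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return TorchDevice.choose_torch_device()
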
4 changes: 2 additions & 2 deletions invokeai/app/api_app.py
@@ -28,7 +28,7 @@
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import get_torch_device_name
from invokeai.backend.util.devices import TorchDevice

from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@@ -63,7 +63,7 @@
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")

torch_device_name = get_torch_device_name()
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")


6 changes: 3 additions & 3 deletions invokeai/app/invocations/compel.py
@@ -24,7 +24,7 @@
ConditioningFieldData,
SDXLConditioningInfo,
)
from invokeai.backend.util.devices import torch_dtype
from invokeai.backend.util.devices import TorchDevice

from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import CLIPField
@@ -99,7 +99,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
)

@@ -193,7 +193,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
43 changes: 13 additions & 30 deletions invokeai/app/invocations/latent.py
@@ -72,15 +72,12 @@
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from ...backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField

if choose_torch_device() == torch.device("mps"):
from torch import mps

DEFAULT_PRECISION = choose_precision(choose_torch_device())
DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()


@invocation_output("scheduler_output")
@@ -959,9 +956,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()

name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
@@ -1028,9 +1023,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
vae.disable_tiling()

# clear memory as vae decode can request a lot
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()

with torch.inference_mode():
# copied from diffusers pipeline
@@ -1042,9 +1035,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

image = VaeImageProcessor.numpy_to_pil(np_image)[0]

torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()

image_dto = context.images.save(image=image)

@@ -1083,9 +1074,7 @@ class ResizeLatentsInvocation(BaseInvocation):

def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)

# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()

resized_latents = torch.nn.functional.interpolate(
latents.to(device),
@@ -1096,9 +1085,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()

TorchDevice.empty_cache()

name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1125,8 +1113,7 @@ class ScaleLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)

# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()

# resizing
resized_latents = torch.nn.functional.interpolate(
@@ -1138,9 +1125,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()

name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1272,8 +1257,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")

# TODO:
device = choose_torch_device()
device = TorchDevice.choose_torch_device()

def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
@@ -1326,9 +1310,8 @@ def slerp(

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
torch.cuda.empty_cache()
if device == torch.device("mps"):
mps.empty_cache()

TorchDevice.empty_cache()

name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents)
8 changes: 4 additions & 4 deletions invokeai/app/invocations/noise.py
@@ -9,7 +9,7 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX

from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.util.devices import TorchDevice
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@@ -46,7 +46,7 @@ def get_noise(
height // downsampling_factor,
width // downsampling_factor,
],
dtype=torch_dtype(device),
dtype=TorchDevice.choose_torch_dtype(device=device),
device=noise_device_type,
generator=generator,
).to("cpu")
@@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):

@field_validator("seed", mode="before")
def modulo_seed(cls, v):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)

def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise(
width=self.width,
height=self.height,
device=choose_torch_device(),
device=TorchDevice.choose_torch_device(),
seed=self.seed,
use_cpu=self.use_cpu,
)
10 changes: 2 additions & 8 deletions invokeai/app/invocations/upscale.py
@@ -4,7 +4,6 @@

import cv2
import numpy as np
import torch
from PIL import Image
from pydantic import ConfigDict

@@ -14,7 +13,7 @@
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice

from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
@@ -35,9 +34,6 @@
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
}

if choose_torch_device() == torch.device("mps"):
from torch import mps


@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -120,9 +116,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
upscaled_image = upscaler.upscale(cv2_image)
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")

torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
TorchDevice.empty_cache()

image_dto = context.images.save(image=pil_image)

57 changes: 43 additions & 14 deletions invokeai/app/services/config/config_default.py
@@ -27,12 +27,12 @@
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.0"
CONFIG_SCHEMA_VERSION = "4.0.1"


def get_default_ram_cache_size() -> float:
@@ -105,7 +105,7 @@ class InvokeAIAppConfig(BaseSettings):
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
@@ -370,6 +370,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
@@ -392,6 +395,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
return config


def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config


def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@@ -418,17 +443,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
else:
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e

if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)

# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e


@lru_cache(maxsize=1)
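
For a quick sense of what the new migration does, a minimal illustration (the field values are made up; only the precision and schema_version handling comes from migrate_v4_0_0_config_dict above, and ram stands in for any setting that is passed through unchanged):

old_config = {"schema_version": "4.0.0", "precision": "autocast", "ram": 7.5}
migrated = migrate_v4_0_0_config_dict(old_config)
assert migrated.precision == "auto"                       # autocast is coerced to auto
assert migrated.schema_version == CONFIG_SCHEMA_VERSION   # bumped to 4.0.1
assert migrated.ram == 7.5                                 # other settings pass through
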
10 changes: 5 additions & 5 deletions invokeai/app/services/model_install/model_install_default.py
@@ -13,6 +13,7 @@
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union

import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
@@ -42,7 +43,7 @@
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.devices import choose_precision, choose_torch_device
from invokeai.backend.util.devices import TorchDevice

from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
@@ -634,11 +635,10 @@ def _next_id(self) -> int:
self._next_job_id += 1
return id

@staticmethod
def _guess_variant() -> Optional[ModelRepoVariant]:
def _guess_variant(self) -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download."""
precision = choose_precision(choose_torch_device())
return ModelRepoVariant.FP16 if precision == "float16" else None
precision = TorchDevice.choose_torch_dtype()
return ModelRepoVariant.FP16 if precision == torch.float16 else None

def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
return ModelInstallJob(
8 changes: 5 additions & 3 deletions invokeai/app/services/model_manager/model_manager_default.py
@@ -1,12 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""

from typing import Optional

import torch
from typing_extensions import Self

from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger

from ..config import InvokeAIAppConfig
@@ -67,7 +69,7 @@ def build_model_manager(
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
execution_device: torch.device = choose_torch_device(),
execution_device: Optional[torch.device] = None,
) -> Self:
"""
Construct the model manager service instance.
Expand All @@ -82,7 +84,7 @@ def build_model_manager(
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
execution_device=execution_device,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(
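
A note on the build_model_manager() signature change above: a default argument such as execution_device: torch.device = choose_torch_device() is evaluated once, when the module is imported, which freezes the device choice before the application is fully configured. Making the parameter Optional and resolving it inside the function defers that decision to construction time. A minimal sketch of the pattern (not the actual service code):

from typing import Optional

import torch

from invokeai.backend.util.devices import TorchDevice


def build_model_manager(execution_device: Optional[torch.device] = None) -> None:
    # Resolve the device lazily, at call time rather than at import time.
    device = execution_device or TorchDevice.choose_torch_device()
    print(f"Model manager will use execution device: {device}")
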