minimal stable diffusion GPU memory usage with accelerate hooks (huggingface#850)

* add method to enable cuda with minimal gpu usage to stable diffusion

* add test for minimal cuda memory usage

* ensure all models but unet are on torch.float32

* move to cpu_offload along with minor internal changes to make it work

* make it test against accelerate master branch

* coming back, it's official: I don't know how to make it test against the master branch from accelerate

* make it install accelerate from master on tests

* go back to accelerate>=0.11

* undo prettier formatting on yml files

* undo prettier formatting on yml files again
piEsposito authored Oct 26, 2022
1 parent b467ca6 commit ff10cc0
Showing 2 changed files with 15 additions and 0 deletions.
pipeline_utils.py (2 additions, 0 deletions)
@@ -223,6 +223,8 @@ def device(self) -> torch.device:
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
+                if module.device == torch.device("meta"):
+                    return torch.device("cpu")
                 return module.device
         return torch.device("cpu")

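For context: accelerate's cpu_offload installs hooks that keep a module's weights off the GPU, and an offloaded module can report the `meta` device, on which no computation is possible. The two added lines map that case back to a usable default. A minimal sketch of the meta-device behavior itself (plain PyTorch, not the diffusers API):

import torch

# Parameters on the "meta" device carry shape and dtype but no storage,
# so nothing can actually run on them.
layer = torch.nn.Linear(4, 4, device="meta")
print(layer.weight.device)  # meta

# The patched `device` property falls back to CPU in exactly this case:
reported = layer.weight.device
if reported == torch.device("meta"):
    reported = torch.device("cpu")
print(reported)  # cpu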
pipelines/stable_diffusion/pipeline_stable_diffusion.py (13 additions, 0 deletions)
@@ -3,6 +3,7 @@
 
 import torch
 
+from diffusers.utils import is_accelerate_available
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -118,6 +119,18 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def cuda_with_minimal_gpu_usage(self):
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+        self.enable_attention_slicing(1)
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
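A minimal usage sketch for the new method, assuming a diffusers build containing this commit plus accelerate installed; the checkpoint id is illustrative. The pipeline is deliberately not moved to CUDA by hand: cpu_offload streams each sub-model to the GPU only for its own forward pass, trading speed for a much smaller peak GPU footprint:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint id; any Stable Diffusion checkpoint should work.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# No pipe.to("cuda") here: the method installs accelerate's cpu_offload hooks
# on unet, text_encoder, vae, and safety_checker, and enables attention
# slicing with slice_size=1 to reduce peak activation memory as well.
pipe.cuda_with_minimal_gpu_usage()

image = pipe("a photograph of an astronaut riding a horse").images[0]
image.save("astronaut.png")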
