From ff10cc0db0522639f2bb0c41c9bf5061e69bfa83 Mon Sep 17 00:00:00 2001
From: Pi Esposito
Date: Wed, 26 Oct 2022 10:52:57 -0300
Subject: [PATCH] minimal stable diffusion GPU memory usage with accelerate
 hooks (#850)

* add method to enable cuda with minimal gpu usage to stable diffusion
* add test for minimal cuda memory usage
* ensure all models but unet are on torch.float32
* move to cpu_offload along with minor internal changes to make it work
* make it test against accelerate master branch
* coming back, it's official: I don't know how to make it test against the
  master branch from accelerate
* make it install accelerate from master on tests
* go back to accelerate>=0.11
* undo prettier formatting on yml files
* undo prettier formatting on yml files again
---
 pipeline_utils.py                                  |  2 ++
 .../stable_diffusion/pipeline_stable_diffusion.py  | 13 +++++++++++++
 2 files changed, 15 insertions(+)

diff --git a/pipeline_utils.py b/pipeline_utils.py
index d307bd5a076e..c9c58a748831 100644
--- a/pipeline_utils.py
+++ b/pipeline_utils.py
@@ -223,6 +223,8 @@ def device(self) -> torch.device:
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
+                if module.device == torch.device("meta"):
+                    return torch.device("cpu")
                 return module.device
 
         return torch.device("cpu")
diff --git a/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 02a6b45fdefc..cf4c5c5fdeca 100644
--- a/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -3,6 +3,7 @@
 
 import torch
 
+from diffusers.utils import is_accelerate_available
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -118,6 +119,18 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def cuda_with_minimal_gpu_usage(self):
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device("cuda")
+        self.enable_attention_slicing(1)
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
+            cpu_offload(cpu_offloaded_model, device)
+
     @torch.no_grad()
     def __call__(
         self,
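
Usage note (editor's sketch, not part of the patch): the new method can be
exercised roughly as below. This is a minimal sketch assuming the patch is
applied on top of diffusers with accelerate>=0.11 installed; the checkpoint
id and file names are illustrative.

    import torch
    from diffusers import StableDiffusionPipeline

    # Illustrative checkpoint; any Stable Diffusion weights work here.
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

    # Attach accelerate's cpu_offload hooks to unet, text_encoder, vae, and
    # safety_checker: weights stay in CPU RAM and are streamed to CUDA only
    # while each module runs. The offloaded modules report a "meta" device,
    # which is why the pipeline_utils.py hunk maps "meta" back to "cpu".
    # Attention slicing is also set to 1 to further cut peak GPU memory.
    pipe.cuda_with_minimal_gpu_usage()

    image = pipe("a photo of an astronaut riding a horse on mars").images[0]
    image.save("astronaut.png")

The trade-off is speed for memory: every forward pass pays the CPU-to-GPU
transfer cost for each offloaded module, so this is meant for GPUs that
cannot hold the full pipeline at once.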