[Accelerate model loading] Fix meta device and super low memory usage #1016

Merged 2 commits on Oct 27, 2022.
@@ -119,14 +119,13 @@ def disable_attention_slicing(self):
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

-    def cuda_with_minimal_gpu_usage(self):
+    def enable_sequential_cpu_offload(self):
Contributor: Great name choice!

        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device("cuda")
-        self.enable_attention_slicing(1)

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
            cpu_offload(cpu_offloaded_model, device)
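
Note the design change above: the old cuda_with_minimal_gpu_usage turned on attention slicing implicitly, while the renamed enable_sequential_cpu_offload leaves that choice to the caller. A minimal usage sketch of the new API (illustrative only: the prompt and step count are placeholders, and a CUDA device with accelerate installed is assumed):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
)

# Opt in to both memory savers explicitly. No .to("cuda") call is needed:
# cpu_offload moves each submodule to the GPU only while it runs.
pipe.enable_attention_slicing(1)
pipe.enable_sequential_cpu_offload()

image = pipe("Andromeda galaxy in a bottle", num_inference_steps=5).images[0]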
tests/pipelines/stable_diffusion/test_stable_diffusion.py (37 additions, 0 deletions)
@@ -15,6 +15,7 @@

import gc
import random
+import time
import unittest

import numpy as np
@@ -730,3 +731,39 @@ def test_callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
        )
        assert test_callback_fn.has_been_called
        assert number_of_steps == 51

+    def test_stable_diffusion_accelerate_auto_device(self):
+        pipeline_id = "CompVis/stable-diffusion-v1-4"
+
+        start_time = time.time()
+        pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
+        )
+        pipeline_normal_load.to(torch_device)
+        normal_load_time = time.time() - start_time
+
+        start_time = time.time()
+        _ = StableDiffusionPipeline.from_pretrained(
+            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
+        )
+        meta_device_load_time = time.time() - start_time
+
+        assert 2 * meta_device_load_time < normal_load_time
Contributor (on lines +736 to +751): Very cool!
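The 2x load-time bound above is plausible because device_map="auto" routes model creation through accelerate's meta-device path: parameters are created without allocation or random initialization, then filled directly from the checkpoint. A minimal sketch of that idea using accelerate's init_empty_weights (the toy nn.Sequential model is illustrative, not the pipeline's actual architecture):

import torch.nn as nn
from accelerate import init_empty_weights

# 1. Build the architecture under init_empty_weights: parameters land on the
#    "meta" device, so no memory is allocated and no random init kernel runs.
with init_empty_weights():
    model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))

print(next(model.parameters()).device)  # meta

# 2. Checkpoint tensors are then materialized directly into the module
#    (e.g. via accelerate's load_checkpoint_and_dispatch), skipping the
#    usual initialize-then-overwrite double pass.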

+    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
+    def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+
+        pipeline_id = "CompVis/stable-diffusion-v1-4"
+        prompt = "Andromeda galaxy in a bottle"
+
+        pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
+        pipeline.enable_attention_slicing(1)
+        pipeline.enable_sequential_cpu_offload()
+
+        _ = pipeline(prompt, num_inference_steps=5)
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+        # make sure that less than 1.5 GB is allocated
+        assert mem_bytes < 1.5 * 10**9
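
The 1.5 GB ceiling holds even though the fp16 weights alone are larger, because accelerate's cpu_offload parks each submodule in CPU RAM and moves it onto the GPU only for its own forward pass. A rough plain-PyTorch sketch of that mechanism using forward hooks (illustrative only; accelerate implements this with its own device-alignment hooks):

import torch
import torch.nn as nn

def naive_sequential_offload(module: nn.Module, device: torch.device) -> None:
    # Park the weights on CPU; shuttle them to `device` around each forward.
    module.to("cpu")

    def pre_hook(mod, args):
        mod.to(device)  # load weights just before forward runs

    def post_hook(mod, args, output):
        mod.to("cpu")  # evict weights right after forward returns

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)

# At most one block occupies GPU memory at a time.
blocks = [nn.Linear(8192, 8192) for _ in range(2)]
for block in blocks:
    naive_sequential_offload(block, torch.device("cuda"))

x = torch.randn(1, 8192, device="cuda")
for block in blocks:
    x = block(x)  # weights stream in before, and out after, each call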
tests/test_pipelines.py (1 addition, 73 deletions)
@@ -17,15 +17,12 @@
import os
import random
import tempfile
-import tracemalloc
import unittest

import numpy as np
import torch

-import accelerate
import PIL
-import transformers
from diffusers import (
    AutoencoderKL,
    DDIMPipeline,
@@ -44,8 +41,7 @@
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
-from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
-from packaging import version
+from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer

@@ -487,71 +483,3 @@ def test_ddpm_ddim_equality_batched(self):

        # the values aren't exactly equal, but the images look the same visually
        assert np.abs(ddpm_images - ddim_images).max() < 1e-1

-    @require_torch_gpu
-    def test_stable_diffusion_accelerate_load_works(self):
Contributor (author): this test doesn't do anything so let's delete it

-        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
-            return
-
-        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
-            return
-
-        model_id = "CompVis/stable-diffusion-v1-4"
-        _ = StableDiffusionPipeline.from_pretrained(
-            model_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
-        ).to(torch_device)

-    @require_torch_gpu
-    def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
-        if version.parse(version.parse(transformers.__version__).base_version) < version.parse("4.23"):
-            return
-
-        if version.parse(version.parse(accelerate.__version__).base_version) < version.parse("0.14"):
-            return
-
-        pipeline_id = "CompVis/stable-diffusion-v1-4"
-
-        torch.cuda.empty_cache()
-        gc.collect()
-
-        tracemalloc.start()
-        pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True
-        )
-        pipeline_normal_load.to(torch_device)
-        _, peak_normal = tracemalloc.get_traced_memory()
-        tracemalloc.stop()
-
-        del pipeline_normal_load
-        torch.cuda.empty_cache()
-        gc.collect()
-
-        tracemalloc.start()
-        _ = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float16, use_auth_token=True, device_map="auto"
-        )
-        _, peak_accelerate = tracemalloc.get_traced_memory()
-
-        tracemalloc.stop()
-
-        assert peak_accelerate < peak_normal

-    @slow
-    @unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
-    def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-
-        pipeline_id = "CompVis/stable-diffusion-v1-4"
-        prompt = "Andromeda galaxy in a bottle"
-
-        pipeline = StableDiffusionPipeline.from_pretrained(
-            pipeline_id, revision="fp16", torch_dtype=torch.float32, use_auth_token=True
-        )
-        pipeline.cuda_with_minimal_gpu_usage()
-
-        _ = pipeline(prompt)
-
-        mem_bytes = torch.cuda.max_memory_allocated()
-        # make sure that less than 0.8 GB is allocated
-        assert mem_bytes < 0.8 * 10**9