Commit 9b14b52: support cpu

illeatmyhat authored and illeatmyhat committed Aug 24, 2022
1 parent 69ae4b3 commit 9b14b52
Showing 7 changed files with 69 additions and 39 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
+__pycache__
+outputs/
+*.egg-info/
7 changes: 4 additions & 3 deletions ldm/models/diffusion/ddim.py
@@ -3,7 +3,7 @@
 import torch
 import numpy as np
 from tqdm import tqdm
-from functools import partial
+from ldm.util import torch_device

 from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
     extract_into_tensor
@@ -15,11 +15,12 @@ def __init__(self, model, schedule="linear", **kwargs):
         self.model = model
         self.ddpm_num_timesteps = model.num_timesteps
         self.schedule = schedule
+        self.device_available = torch_device.type

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != torch.device(self.device_available):
+                attr = attr.to(torch.float32).to(torch.device(self.device_available))
         setattr(self, name, attr)

     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
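The register_buffer change above is the core of the port: sampler buffers were previously forced onto CUDA unconditionally, while the patched version moves them to whatever device was detected and casts them to float32 first, since MPS does not support float64 tensors. The identical two-line change lands in plms.py below. A minimal standalone sketch of the pattern (SamplerBase is an illustrative stand-in, not a class from this repository):

    import torch

    # Stand-in for ldm.util.torch_device, which this commit adds further down.
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    class SamplerBase:
        def __init__(self):
            self.device_available = torch_device.type

        def register_buffer(self, name, attr):
            # Cast to float32 before moving; MPS has no float64 support.
            if isinstance(attr, torch.Tensor) and attr.device != torch.device(self.device_available):
                attr = attr.to(torch.float32).to(torch.device(self.device_available))
            setattr(self, name, attr)


    sampler = SamplerBase()
    sampler.register_buffer("alphas_cumprod", torch.linspace(0.9, 0.999, 10, dtype=torch.float64))
    # float32 after a cuda/mps move; stays float64 on a cpu-only machine, where no move is needed
    print(sampler.alphas_cumprod.dtype, sampler.alphas_cumprod.device)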
7 changes: 4 additions & 3 deletions ldm/models/diffusion/plms.py
@@ -3,7 +3,7 @@
 import torch
 import numpy as np
 from tqdm import tqdm
-from functools import partial
+from ldm.util import torch_device

 from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like

@@ -14,11 +14,12 @@ def __init__(self, model, schedule="linear", **kwargs):
         self.model = model
         self.ddpm_num_timesteps = model.num_timesteps
         self.schedule = schedule
+        self.device_available = torch_device.type

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != torch.device(self.device_available):
+                attr = attr.to(torch.float32).to(torch.device(self.device_available))
         setattr(self, name, attr)

     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
19 changes: 10 additions & 9 deletions ldm/modules/encoders/modules.py
@@ -5,9 +5,9 @@
 from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
-
+from ldm.util import torch_device
 from ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test


 class AbstractEncoder(nn.Module):
     def __init__(self):
@@ -35,7 +35,7 @@ def forward(self, batch, key=None):

 class TransformerEmbedder(AbstractEncoder):
     """Some transformer encoder layers"""
-    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
+    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device=torch_device):
         super().__init__()
         self.device = device
         self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
@@ -52,7 +52,7 @@ def encode(self, x):

 class BERTTokenizer(AbstractEncoder):
     """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
-    def __init__(self, device="cuda", vq_interface=True, max_length=77):
+    def __init__(self, device=torch_device, vq_interface=True, max_length=77):
         super().__init__()
         from transformers import BertTokenizerFast  # TODO: add to reuquirements
         self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
@@ -80,7 +80,7 @@ def decode(self, text):
 class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizr model and add some transformer encoder layers"""
     def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
-                 device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+                 device=torch_device,use_tokenizer=True, embedding_dropout=0.0):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
         if self.use_tknz_fn:
@@ -134,9 +134,10 @@ def forward(self,x):
     def encode(self, x):
         return self(x)

+
 class FrozenCLIPEmbedder(AbstractEncoder):
     """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+    def __init__(self, version="openai/clip-vit-large-patch14", device=torch_device, max_length=77):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -166,7 +167,7 @@ class FrozenCLIPTextEmbedder(nn.Module):
     """
     Uses the CLIP transformer encoder for text.
     """
-    def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
+    def __init__(self, version='ViT-L/14', device=torch_device, max_length=77, n_repeat=1, normalize=True):
         super().__init__()
         self.model, _ = clip.load(version, jit=False, device="cpu")
         self.device = device
@@ -202,7 +203,7 @@ def __init__(
             self,
             model,
             jit=False,
-            device='cuda' if torch.cuda.is_available() else 'cpu',
+            device=torch_device,
             antialias=False,
         ):
         super().__init__()
@@ -231,4 +232,4 @@ def forward(self, x):
 if __name__ == "__main__":
     from ldm.util import count_params
     model = FrozenCLIPEmbedder()
-    count_params(model, verbose=True)
\ No newline at end of file
+    count_params(model, verbose=True)
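All of these encoder constructors previously hard-coded device="cuda"; they now default to the shared torch_device object. One subtlety worth knowing: a Python default argument is evaluated once, at import time, so every encoder built without an explicit device argument picks up whatever device get_device() chose when ldm.util was first imported. A small sketch of the pattern, assuming the repository is on the import path (ExampleEmbedder is illustrative, not a class from this diff):

    import torch

    from ldm.util import torch_device  # module-level device, resolved once at import


    class ExampleEmbedder:
        def __init__(self, device=torch_device, max_length=77):
            # `device` is a torch.device("cuda"/"mps"/"cpu") bound at import time.
            self.device = device
            self.max_length = max_length

        def encode(self, tokens):
            # Move inputs to the same device the weights live on.
            return tokens.to(self.device)


    emb = ExampleEmbedder()
    print(emb.device)  # cuda, mps, or cpu, depending on the host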
16 changes: 16 additions & 0 deletions ldm/util.py
@@ -14,6 +14,22 @@
 from PIL import Image, ImageDraw, ImageFont


+# noinspection PyBroadException
+def get_device():
+    try:
+        if torch.cuda.is_available():
+            return 'cuda'
+        elif torch.backends.mps.is_available():
+            return 'mps'
+        else:
+            return 'cpu'
+    except Exception:
+        return 'cpu'
+
+
+torch_device = torch.device(get_device())
+
+
 def log_txt_as_img(wh, xc, size=10):
     # wh a tuple of (width, height)
     # xc a list of captions to plot
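The broad try/except in get_device() is what keeps older installs working: torch.backends.mps only exists from PyTorch 1.12 on, so on earlier builds the attribute access raises and the helper falls back to 'cpu'. A quick way to see what a given machine resolves to (again assuming the repository is importable):

    from ldm.util import get_device, torch_device

    print(get_device())       # 'cuda', 'mps', or 'cpu'
    print(torch_device)       # the torch.device shared by every patched module
    print(torch_device.type)  # the string the samplers store as device_available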
20 changes: 11 additions & 9 deletions scripts/img2img.py
@@ -15,7 +15,7 @@
 import time
 from pytorch_lightning import seed_everything

-from ldm.util import instantiate_from_config
+from ldm.util import instantiate_from_config, torch_device as device
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler

@@ -40,7 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
+    model.to(device)
     model.eval()
     return model

@@ -170,13 +170,13 @@ def main():
     parser.add_argument(
         "--config",
         type=str,
-        default="configs/stable-diffusion/v1-inference.yaml",
+        default="configs/stable_diffusion/v1-inference.yaml",
         help="path to config which constructs model",
     )
     parser.add_argument(
         "--ckpt",
         type=str,
-        default="models/ldm/stable-diffusion-v1/model.ckpt",
+        default="models/ldm/stable_diffusion-v1/model.ckpt",
         help="path to checkpoint of model",
     )
     parser.add_argument(
@@ -199,7 +199,6 @@ def main():
     config = OmegaConf.load(f"{opt.config}")
     model = load_model_from_config(config, f"{opt.ckpt}")

-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)

     if opt.plms:
@@ -230,7 +229,7 @@ def main():
     grid_count = len(os.listdir(outpath)) - 1

     assert os.path.isfile(opt.init_img)
-    init_image = load_img(opt.init_img).to(device)
+    init_image = load_img(opt.init_img).to(torch_device)
     init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
     init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

@@ -240,9 +239,12 @@ def main():
     t_enc = int(opt.strength * opt.ddim_steps)
     print(f"target t_enc is {t_enc} steps")

-    precision_scope = autocast if opt.precision == "autocast" else nullcontext
+    if torch_device.type in ['mps', 'cpu']:
+        precision_scope = nullcontext  # have to use f32 on mps
+    else:
+        precision_scope = autocast if opt.precision == "autocast" else nullcontext
     with torch.no_grad():
-        with precision_scope("cuda"):
+        with precision_scope(torch_device.type):
             with model.ema_scope():
                 tic = time.time()
                 all_samples = list()
@@ -256,7 +258,7 @@ def main():
                         c = model.get_learned_conditioning(prompts)

                         # encode (scaled latent)
-                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
+                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(torch_device))
                         # decode it
                         samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
                                                  unconditional_conditioning=uc,)
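The precision_scope rework above is the behavioral core of the script changes: CUDA runs can still opt into autocast (mixed fp16), while MPS and CPU are pinned to fp32 by swapping in nullcontext, which simply ignores the device string it is handed. One caveat in the hunks as shown: img2img.py imports torch_device as device yet the patched lines reference torch_device directly, so both names would need to resolve for the script to run; the sketch below sidesteps that by defining a single local torch_device as a stand-in for ldm.util.torch_device:

    from contextlib import nullcontext

    import torch
    from torch import autocast

    # Stand-in for ldm.util.torch_device.
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if torch_device.type in ("mps", "cpu"):
        precision_scope = nullcontext  # fp32 only; nullcontext accepts and ignores the device string
    else:
        precision_scope = autocast

    with torch.no_grad():
        with precision_scope(torch_device.type):
            x = torch.randn(8, 8, device=torch_device)
            y = x @ x  # fp16 under CUDA autocast, fp32 otherwise
            print(y.dtype)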
36 changes: 21 additions & 15 deletions scripts/txt2img.py
@@ -1,22 +1,27 @@
-import argparse, os, sys, glob
-import cv2
-import torch
+import argparse
+import os
+import time
+import intel_extension_for_pytorch as ipex
+
 import numpy as np
-from omegaconf import OmegaConf
+import torch
+
+from contextlib import nullcontext
 from PIL import Image
-from tqdm import tqdm, trange
 from imwatermark import WatermarkEncoder
 from itertools import islice
 from einops import rearrange
-from torchvision.utils import make_grid
-import time
+from omegaconf import OmegaConf
 from pytorch_lightning import seed_everything
 from torch import autocast
-from contextlib import contextmanager, nullcontext
+from torchvision.utils import make_grid
+from tqdm import tqdm, trange

-from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
+from ldm.util import instantiate_from_config, torch_device as device


 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from transformers import AutoFeatureExtractor
@@ -60,7 +65,7 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
+    model.to(device)
     model.eval()
     return model

@@ -238,9 +243,7 @@ def main():

     config = OmegaConf.load(f"{opt.config}")
     model = load_model_from_config(config, f"{opt.ckpt}")
-
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model = model.to(device)
+    model = ipex.optimize(model)

     if opt.plms:
         sampler = PLMSSampler(model)
@@ -277,13 +280,16 @@ def main():
     if opt.fixed_code:
         start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)

-    precision_scope = autocast if opt.precision=="autocast" else nullcontext
+    if device.type in ['mps', 'cpu']:
+        precision_scope = nullcontext  # have to use f32 on mps
+    else:
+        precision_scope = autocast if opt.precision == "autocast" else nullcontext
     with torch.no_grad():
-        with precision_scope("cuda"):
+        with precision_scope(device.type):
             with model.ema_scope():
                 tic = time.time()
                 all_samples = list()
-                for n in trange(opt.n_iter, desc="Sampling"):
+                for _ in trange(opt.n_iter, desc="Sampling"):
                     for prompts in tqdm(data, desc="data"):
                         uc = None
                         if opt.scale != 1.0:
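txt2img.py gets the same precision dispatch as img2img.py, plus one extra step: the loaded model is passed through intel_extension_for_pytorch, whose optimize() returns a copy of an eval-mode module rewritten for Intel CPUs (operator fusion, optimized weight layouts). Because both the import and the optimize call are unconditional here, the patched script requires the package even on CUDA machines. A minimal sketch of the call, assuming ipex is installed (the toy model is illustrative):

    import torch
    import intel_extension_for_pytorch as ipex

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.ReLU(),
    )
    model.eval()  # ipex.optimize targets inference on an eval-mode module

    model = ipex.optimize(model)  # returns the optimized module

    with torch.no_grad():
        out = model(torch.randn(1, 3, 64, 64))
    print(out.shape)  # torch.Size([1, 8, 64, 64])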
