
Unable to compile StableDiffusion #1244

Open
renderless opened this issue Aug 11, 2023 · 0 comments

Describe the bug
Using the sample code below to compile the StableDiffusion model, the compiler fails while exporting it through torch.jit.script.

To Reproduce

  1. pull the latest runtime image via docker pull bladedisc/bladedisc:latest-runtime-torch1.13.1-cu116
  2. install the following packages:
diffusers==0.19.3
transformers==4.29.2
accelerate==0.21.0
  3. run the example code below

Example code

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        revision="fp16",
        torch_dtype=torch.float16).to("cuda")

import torch_blade
opt_cfg = torch_blade.Config()
opt_cfg.enable_fp16 = True

batch_size = 1
seq_len = 77
width = 512
height = 512
latent_channels = 4

# the VAE downsamples by a factor of 8, so latents are (height // 8, width // 8)
ww = width // 8
hh = height // 8

# text encoder: token ids and position ids (embedding indices must be int64)
encoder_ids = torch.ones(batch_size, seq_len, dtype=torch.long)
encoder_pos = torch.ones(batch_size, seq_len, dtype=torch.long)
encoder_inputs = (encoder_ids, encoder_pos)

# unet: diffusers' UNet takes channels-first latents, i.e. (N, C, H, W)
unet_latent = torch.rand(batch_size, latent_channels, hh, ww)
unet_timesteps = torch.rand(batch_size)
unet_text_embeddings = torch.rand(batch_size, seq_len, pipe.unet.config.cross_attention_dim)
unet_inputs = (unet_latent, unet_timesteps, unet_text_embeddings)

# vae decoder: latent input, also channels-first
decoder_inputs = torch.rand(batch_size, latent_channels, hh, ww)

# optimize each submodule against its example inputs
with opt_cfg, torch.no_grad():
    encoder = torch_blade.optimize(pipe.text_encoder, model_inputs=encoder_inputs, allow_tracing=True)
    unet = torch_blade.optimize(pipe.unet, model_inputs=unet_inputs, allow_tracing=True)
    decoder = torch_blade.optimize(pipe.vae.decoder, model_inputs=decoder_inputs, allow_tracing=True)
    

Expected behavior
torch_blade.optimize should run without problems.
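
Additional context
For reference, a minimal check to isolate whether the failure comes from TorchScript export itself rather than from BladeDISC's passes (this snippet is my own sketch, not part of the original repro) is to call torch.jit.script on each submodule directly:

import torch

# Sketch: if torch.jit.script alone raises on a submodule, the export
# failure is in TorchScript itself and would reproduce without torch_blade.
for name, module in [("text_encoder", pipe.text_encoder),
                     ("unet", pipe.unet),
                     ("vae.decoder", pipe.vae.decoder)]:
    try:
        torch.jit.script(module)
        print(f"{name}: scripted OK")
    except Exception as e:
        print(f"{name}: torch.jit.script failed: {e}")

If scripting already fails here, the problem is upstream of torch_blade.optimize.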
