SD PNDMScheduler + Unet example through Turbine #403

Merged 15 commits on Feb 17, 2024
1 change: 1 addition & 0 deletions core/shark_turbine/dynamo/passes.py
@@ -48,6 +48,7 @@
torch.ops.aten._log_softmax_backward_data,
torch.ops.aten.lift_fresh_copy.default,
torch.ops.aten._unsafe_index.Tensor,
torch.ops.aten.unbind.int,
# decompositions added manually in this file
torch.ops.aten._scaled_dot_product_flash_attention.default,
]
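For context, a minimal sketch (not part of this change) of what aten.unbind.int computes; registering it for decomposition lets the FX import path lower it to plain indexing it already supports:

import torch

x = torch.arange(6).reshape(2, 3)
a, b = torch.unbind(x, dim=0)  # the op being decomposed
a_ref, b_ref = x[0], x[1]  # an equivalent indexed form
assert torch.equal(a, a_ref) and torch.equal(b, b_ref)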
178 changes: 178 additions & 0 deletions models/turbine_models/custom_models/sd_inference/schedulers.py
@@ -0,0 +1,178 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import sys

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from shark_turbine.aot import *
from iree import runtime as ireert
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np

from turbine_models.custom_models.sd_inference import utils
from diffusers import (
UNet2DConditionModel,
)

import safetensors
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
"--hf_auth_token", type=str, help="The Hugging Face auth token, required"
)
parser.add_argument(
"--hf_model_name",
type=str,
help="HF model name",
default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument(
"--scheduler_id",
type=str,
help="Scheduler ID",
default="PNDM",
)
parser.add_argument(
"--num_inference_steps", type=int, default=50, help="Number of inference steps"
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for inference"
)
parser.add_argument(
"--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_path", type=str, default="")
parser.add_argument(
"--external_weights",
type=str,
default=None,
help="saves ir/vmfb without global weights for size and readability, options [safetensors]",
)
parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm")
# TODO: Bring in detection for target triple
parser.add_argument(
"--iree_target_triple",
type=str,
default="",
help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")


class Scheduler(torch.nn.Module):
def __init__(self, hf_model_name, num_inference_steps, scheduler):
super().__init__()
self.scheduler = scheduler
self.scheduler.set_timesteps(num_inference_steps)
self.unet = UNet2DConditionModel.from_pretrained(
hf_model_name,
subfolder="unet",
)
self.guidance_scale = 7.5

def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor:
latents = latents * self.scheduler.init_noise_sigma
for t in self.scheduler.timesteps:
latent_model_input = torch.cat([latents] * 2)
t = t.unsqueeze(0)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, timestep=t
)
unet_out = self.unet.forward(
latent_model_input, t, encoder_hidden_states, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
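# Classifier-free guidance: push the prediction toward the text-conditioned
# estimate by guidance_scale (uncond + scale * (text - uncond)).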
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
return latents


def export_scheduler(
scheduler,
hf_model_name,
batch_size,
height,
width,
hf_auth_token=None,
compile_to="torch",
external_weights=None,
external_weight_path=None,
device=None,
target_triple=None,
max_alloc=None,
):
mapper = {}
utils.save_external_weights(
mapper, scheduler, external_weights, external_weight_path
)

encoder_hidden_states_sizes = (2, 77, 768)
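# (2 = uncond + text batch for classifier-free guidance; 77 = CLIP token
# length; 768 = SD 1.x text-encoder hidden size, 1024 for SD 2.1 below.)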
if hf_model_name == "stabilityai/stable-diffusion-2-1-base":
encoder_hidden_states_sizes = (2, 77, 1024)

sample = (batch_size, 4, height // 8, width // 8)
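# (The SD VAE downsamples images 8x and produces 4 latent channels.)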

class CompiledScheduler(CompiledModule):
if external_weights:
params = export_parameters(
scheduler, external=True, external_scope="", name_mapper=mapper.get
)
else:
params = export_parameters(scheduler)

def main(
self,
sample=AbstractTensor(*sample, dtype=torch.float32),
encoder_hidden_states=AbstractTensor(
*encoder_hidden_states_sizes, dtype=torch.float32
),
):
return jittable(scheduler.forward)(sample, encoder_hidden_states)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = CompiledScheduler(context=Context(), import_to=import_to)

module_str = str(CompiledModule.get_mlir_module(inst))
safe_name = utils.create_safe_name(hf_model_name, "-scheduler")
if compile_to != "vmfb":
return module_str
else:
utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)


if __name__ == "__main__":
args = parser.parse_args()
schedulers = utils.get_schedulers(args.hf_model_name)
scheduler = schedulers[args.scheduler_id]
scheduler_module = Scheduler(
args.hf_model_name, args.num_inference_steps, scheduler
)
mod_str = export_scheduler(
scheduler_module,
args.hf_model_name,
args.batch_size,
args.height,
args.width,
args.hf_auth_token,
args.compile_to,
args.external_weights,
args.external_weight_path,
args.device,
args.iree_target_triple,
args.vulkan_max_allocation,
)
safe_name = utils.create_safe_name(args.hf_model_name, "-scheduler")
with open(f"{safe_name}.mlir", "w+") as f:
f.write(mod_str)
print("Saved to", safe_name + ".mlir")
172 changes: 172 additions & 0 deletions models/turbine_models/custom_models/sd_inference/schedulers_runner.py
@@ -0,0 +1,172 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
from turbine_models.model_runner import vmfbRunner
from iree import runtime as ireert
import torch
from diffusers import (
PNDMScheduler,
UNet2DConditionModel,
)

parser = argparse.ArgumentParser()

# TODO move common runner flags to generic flag file
parser.add_argument(
"--scheduler_id",
type=str,
help="Scheduler ID",
default="PNDM",
)
parser.add_argument(
"--num_inference_steps", type=int, default=50, help="Number of inference steps"
)
parser.add_argument(
"--vmfb_path", type=str, default="", help="path to vmfb containing compiled module"
)
parser.add_argument(
"--external_weight_path",
type=str,
default="",
help="path to external weight parameters if model compiled without them",
)
parser.add_argument(
"--compare_vs_torch",
action="store_true",
help="Runs both turbine vmfb and a torch model to compare results",
)
parser.add_argument(
"--hf_model_name",
type=str,
help="HF model name",
default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument(
"--hf_auth_token",
type=str,
help="The Hugging face auth token, required for some models",
)
parser.add_argument(
"--device",
type=str,
default="local-task",
help="local-sync, local-task, cuda, vulkan, rocm",
)
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for inference"
)
parser.add_argument(
"--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")


def run_scheduler(
device,
sample,
encoder_hidden_states,
vmfb_path,
hf_model_name,
hf_auth_token,
external_weight_path,
):
runner = vmfbRunner(device, vmfb_path, external_weight_path)

inputs = [
ireert.asdevicearray(runner.config.device, sample),
ireert.asdevicearray(runner.config.device, encoder_hidden_states),
]
results = runner.ctx.modules.compiled_scheduler["main"](*inputs)
return results


def run_torch_scheduler(
hf_model_name, scheduler, num_inference_steps, sample, encoder_hidden_states
):
class Scheduler(torch.nn.Module):
def __init__(self, hf_model_name, num_inference_steps, scheduler):
super().__init__()
self.scheduler = scheduler
self.scheduler.set_timesteps(num_inference_steps)
self.unet = UNet2DConditionModel.from_pretrained(
hf_model_name,
subfolder="unet",
)
self.guidance_scale = 7.5

def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor:
latents = latents * self.scheduler.init_noise_sigma
for t in self.scheduler.timesteps:
latent_model_input = torch.cat([latents] * 2)
t = t.unsqueeze(0)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, timestep=t
)
unet_out = self.unet.forward(
latent_model_input, t, encoder_hidden_states, return_dict=False
)[0]
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
latents = self.scheduler.step(
noise_pred, t, latents, return_dict=False
)[0]
return latents

scheduler_module = Scheduler(hf_model_name, num_inference_steps, scheduler)
results = scheduler_module.forward(sample, encoder_hidden_states)
np_torch_output = results.detach().cpu().numpy()
return np_torch_output


if __name__ == "__main__":
args = parser.parse_args()
sample = torch.rand(
args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32
)
if args.hf_model_name == "CompVis/stable-diffusion-v1-4":
encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32)
elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base":
encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32)
else:
# Fail early rather than hitting an undefined encoder_hidden_states below.
raise ValueError(f"Unsupported hf_model_name: {args.hf_model_name}")

turbine_output = run_scheduler(
args.device,
sample,
encoder_hidden_states,
args.vmfb_path,
args.hf_model_name,
args.hf_auth_token,
args.external_weight_path,
)
print(
"TURBINE OUTPUT:",
turbine_output.to_host(),
turbine_output.to_host().shape,
turbine_output.to_host().dtype,
)

if args.compare_vs_torch:
print("generating torch output: ")
from turbine_models.custom_models.sd_inference import utils

schedulers = utils.get_schedulers(args.hf_model_name)
scheduler = schedulers[args.scheduler_id]
torch_output = run_torch_scheduler(
args.hf_model_name,
scheduler,
args.num_inference_steps,
sample,
encoder_hidden_states,
)
print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
err = utils.largest_error(torch_output, turbine_output)
print("Largest Error: ", err)
assert err < 9e-3

# TODO: Figure out why we occasionally segfault without unlinking output variables
turbine_output = None
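A hypothetical runner invocation (the vmfb filename shown follows the create_safe_name convention; substitute your actual artifact paths):

python models/turbine_models/custom_models/sd_inference/schedulers_runner.py --vmfb_path stable_diffusion_v1_4_scheduler.vmfb --device local-task --compare_vs_torch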
22 changes: 22 additions & 0 deletions models/turbine_models/custom_models/sd_inference/utils.py
@@ -2,6 +2,9 @@
import numpy as np
import safetensors
import re
from diffusers import (
PNDMScheduler,
)


def save_external_weights(
@@ -35,6 +38,7 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
"--iree-llvmcpu-target-triple=x86_64-linux-gnu",
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
"--iree-flow-inline-constants-max-byte-length=1",
]
if device == "cpu":
flags.append("--iree-llvmcpu-enable-ukernels=all")
@@ -86,3 +90,21 @@ def create_safe_name(hf_model_name, model_name_str):
safe_name = hf_model_name.split("/")[-1].strip() + model_name_str
safe_name = re.sub("-", "_", safe_name)
return safe_name


def get_schedulers(model_id):
# TODO: Robust scheduler setup on pipeline creation -- if we don't
# set batch_size here, the SHARK schedulers will
# compile with batch size = 1 regardless of whether the model
# outputs latents of a larger batch size, e.g. SDXL.
# However, obviously, searching for whether the base model ID
# contains "xl" is not very robust.

batch_size = 2 if "xl" in model_id.lower() else 1
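# (batch_size is computed here but not yet consumed; see the TODO above.)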

schedulers = dict()
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
return schedulers
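A short sketch of how this helper is consumed, mirroring the __main__ blocks above:

from turbine_models.custom_models.sd_inference import utils

schedulers = utils.get_schedulers("CompVis/stable-diffusion-v1-4")
scheduler = schedulers["PNDM"]  # currently the only scheduler returned
scheduler.set_timesteps(50)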