(SD) Bump punet revision to d30d6ff and enable punet test. #774

Open. Wants to merge 41 commits into base: main.
Commits (41):
a3376d9  Bump punet revision to d30d6ff (eagarvey-amd, Jul 12, 2024)
7cabac0  Enable punet t2i test. (eagarvey-amd, Jul 12, 2024)
7dfd4c8  Use formatted strings as input to printer. (eagarvey-amd, Jul 12, 2024)
1cd3ee9  Rework sdxl test to setup with a pipeline, fix unloading submodels, f… (eagarvey-amd, Jul 12, 2024)
1a90abd  Add switch for punet preprocessing flags (eagarvey-amd, Jul 13, 2024)
b70318d  Xfail punet e2e test. (eagarvey-amd, Jul 13, 2024)
2d7ebcd  Fixups to sdxl test arguments (eagarvey-amd, Jul 15, 2024)
feebc87  Fix flagset arg and enable vae encode. (eagarvey-amd, Jul 15, 2024)
af7782b  Enable VAE encode validation, mark as xfail (eagarvey-amd, Jul 15, 2024)
eff59a9  Fix formatting (eagarvey-amd, Jul 15, 2024)
63fb053  fix runner function name in old sd test. (eagarvey-amd, Jul 15, 2024)
aff48ab  Fix xfail syntax. (eagarvey-amd, Jul 15, 2024)
b10ad8d  Update unet script for compile function signature change (eagarvey-amd, Jul 15, 2024)
321d21d  Update punet to 4d4f955 (IanNod, Jul 16, 2024)
2de912e  Disable vulkan test on MI250 runner. (monorimet, Jul 16, 2024)
9fdc07f  Change tqdm disable conditions and deepcopy model map on init. (eagarvey-amd, Jul 17, 2024)
b20be32  Don't break workarounds for model path (monorimet, Jul 17, 2024)
02705a9  Fix for passing a path as attn_spec. (eagarvey-amd, Jul 18, 2024)
9229aed  Bump punet revision to defeb489fe2bb17b77d587924db9e58048a8c140 (eagarvey-amd, Jul 19, 2024)
f09ef4a  Move JIT cpu scheduling load helpers inside conditional. (eagarvey-amd, Jul 19, 2024)
bbcc424  formatting (eagarvey-amd, Jul 19, 2024)
1f19c7f  Don't pass benchmark as an export arg. (eagarvey-amd, Jul 19, 2024)
39c0c00  Changes so no external downloads. (#781) (saienduri, Jul 19, 2024)
3c59b25  fix so that we check exact paths as well for is_prepared (#782) (saienduri, Jul 19, 2024)
2e9de46  Update punet to 60edc91 (IanNod, Jul 20, 2024)
aa0ac2b  Vae weight path none check (#784) (saienduri, Jul 21, 2024)
6556a36  Bump punet to mi300_all_sym_8_step10 (62785ea) (monorimet, Jul 22, 2024)
2c49cb6  Changes so that the default run without quant docker will work as wel… (saienduri, Jul 22, 2024)
cb911b1  Bump punet to 361df65844e0a7c766484707c57f6248cea9587f (eagarvey-amd, Jul 22, 2024)
d857f77  Sync flags to sdxl-scripts repo (#786) (saienduri, Jul 23, 2024)
37548f2  Integrate int8 tk kernels (#783) (nithinsubbiah, Jul 23, 2024)
25b2462  Update punet revision to deterministic version (42e9407) (monorimet, Jul 23, 2024)
0e57b4e  Integration of tk kernels into pipeline (#789) (saienduri, Jul 24, 2024)
920dbf5  Update unet horizontal fusion flag (#790) (saienduri, Jul 25, 2024)
6f16731  Revert "Update unet horizontal fusion flag (#790)" (saienduri, Jul 25, 2024)
15dbd93  [tk kernel] Add support to match kernel with number of arguments and … (nithinsubbiah, Jul 25, 2024)
0c02652  Add functionality to SD pipeline and abstracted components for saving… (monorimet, Jul 25, 2024)
3fd954b  Remove download links for tk kernels and instead specify kernel direc… (nithinsubbiah, Jul 25, 2024)
7f8a2b0  Update to best iteration on unet weights (#794) (saienduri, Jul 25, 2024)
bf63aec  Add missing tk_kernel_args arg in function calls (#795) (nithinsubbiah, Jul 25, 2024)
a74d98e  update hash for config file (saienduri, Jul 25, 2024)
.github/workflows/test_models.yml (3 changes: 1 addition & 2 deletions)

@@ -70,7 +70,6 @@ jobs:
 
           pytest -v models/turbine_models/tests/sd_test.py
           pytest -v models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5
-          pytest -v models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux
           pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default
           pytest -v models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 --attn_spec default --batch_size 2
           pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5
           pytest -v models/turbine_models/tests/sd3_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu --num_inference_steps 5
models/turbine_models/custom_models/pipeline_base.py (39 changes: 35 additions & 4 deletions)

@@ -84,14 +84,21 @@ class PipelineComponent:
"""

def __init__(
self, printer, dest_type="devicearray", dest_dtype="float16", benchmark=False
self,
printer,
dest_type="devicearray",
dest_dtype="float16",
benchmark=False,
save_outputs=False,
):
self.runner = None
self.module_name = None
self.device = None
self.metadata = None
self.printer = printer
self.benchmark = benchmark
self.save_outputs = save_outputs
self.output_counter = 0
self.dest_type = dest_type
self.dest_dtype = dest_dtype

Expand Down Expand Up @@ -218,6 +225,16 @@ def _output_cast(self, output):
case _:
return output

def save_output(self, function_name, output):
if isinstance(output, tuple) or isinstance(output, list):
for i in output:
self.save_output(function_name, i)
else:
np.save(
f"{function_name}_output_{self.output_counter}.npy", output.to_host()
)
self.output_counter += 1

def _run(self, function_name, inputs: list):
return self.module[function_name](*inputs)

@@ -239,6 +256,8 @@ def __call__(self, function_name, inputs: list):
             output = self._run_and_benchmark(function_name, inputs)
         else:
             output = self._run(function_name, inputs)
+        if self.save_outputs:
+            self.save_output(function_name, output)
         output = self._output_cast(output)
         return output

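For reference, a minimal sketch of consuming the files written by the new save_output hook. The function name "main" and the two saved calls are hypothetical; the filename pattern {function_name}_output_{counter}.npy comes from the diff above.

```python
import numpy as np

# Hypothetical: a component whose function_name is "main" was invoked twice
# with save_outputs=True, producing main_output_0.npy and main_output_1.npy.
for i in range(2):
    arr = np.load(f"main_output_{i}.npy")
    print(i, arr.shape, arr.dtype)
```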
@@ -340,6 +359,7 @@ def __init__(
         hf_model_name: str | dict[str] = None,
         benchmark: bool | dict[bool] = False,
         verbose: bool = False,
+        save_outputs: bool | dict[bool] = False,
         common_export_args: dict = {},
     ):
         self.map = model_map
@@ -374,6 +394,7 @@ def __init__(
"external_weights": external_weights,
"hf_model_name": hf_model_name,
"benchmark": benchmark,
"save_outputs": save_outputs,
}
for arg in map_arguments.keys():
self.map = merge_arg_into_map(self.map, map_arguments[arg], arg)
@@ -391,7 +412,8 @@ def __init__(
         )
         for submodel in self.map.keys():
             for key, value in map_arguments.items():
-                self.map = merge_export_arg(self.map, value, key)
+                if key not in ["benchmark", "save_outputs"]:
+                    self.map = merge_export_arg(self.map, value, key)
             for key, value in self.map[submodel].get("export_args", {}).items():
                 if key == "hf_model_name":
                     self.map[submodel]["keywords"].append(
@@ -539,7 +561,11 @@ def is_prepared(self, vmfbs, weights):
                 avail_files = os.listdir(self.external_weights_dir)
                 candidates = []
                 for filename in avail_files:
-                    if all(str(x) in filename for x in w_keywords):
+                    if all(
+                        str(x) in filename
+                        or str(x) == os.path.join(self.external_weights_dir, filename)
+                        for x in w_keywords
+                    ):
                         candidates.append(
                             os.path.join(self.external_weights_dir, filename)
                         )
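As a sketch of what the widened check buys (directory, filename, and keyword below are hypothetical): a weight keyword that is an exact path now counts as a match even though it is not a substring of the bare filename.

```python
import os

external_weights_dir = "./weights"
filename = "unet_punet_dataset_i8.irpa"                      # file in the dir
w_keywords = [os.path.join(external_weights_dir, filename)]  # exact-path keyword

old_match = all(str(x) in filename for x in w_keywords)
new_match = all(
    str(x) in filename or str(x) == os.path.join(external_weights_dir, filename)
    for x in w_keywords
)
print(old_match, new_match)  # False True: exact paths are now treated as prepared
```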
@@ -723,7 +749,7 @@ def export_submodel(
     def load_map(self):
         for submodel in self.map.keys():
             if not self.map[submodel]["load"]:
-                self.printer.print("Skipping load for ", submodel)
+                self.printer.print(f"Skipping load for {submodel}")
                 continue
             self.load_submodel(submodel)

@@ -739,6 +765,7 @@ def load_submodel(self, submodel):
             printer=self.printer,
             dest_type=dest_type,
             benchmark=self.map[submodel].get("benchmark", False),
+            save_outputs=self.map[submodel].get("save_outputs", False),
         )
         self.map[submodel]["runner"].load(
             self.map[submodel]["driver"],
@@ -751,6 +778,10 @@ def load_submodel(self, submodel):

     def unload_submodel(self, submodel):
         self.map[submodel]["runner"].unload()
+        self.map[submodel]["vmfb"] = None
+        self.map[submodel]["mlir"] = None
+        self.map[submodel]["weights"] = None
+        self.map[submodel]["export_args"]["input_mlir"] = None
         setattr(self, submodel, None)


@@ -151,6 +151,12 @@ def is_valid_file(arg):
     help="A comma-separated list of submodel IDs for which to report benchmarks, or 'all' for all components.",
 )
 
+p.add_argument(
+    "--save_outputs",
+    type=str,
+    default=None,
+    help="A comma-separated list of submodel IDs for which to save output .npy files, or 'all' for all components.",
+)
 ##############################################################################
 # SDXL Modelling Options
 # These options are used to control model defining parameters for SDXL.
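A minimal sketch of how the comma-separated flag value is expected to expand into per-submodel settings, mirroring the parsing added to sd_pipeline.py's main block further down. The helper name and the explicit dict initialization are assumptions, not code from the diff.

```python
# Hypothetical helper mirroring the --save_outputs handling in sd_pipeline.py.
def parse_save_outputs(value):
    if not value:
        return False          # saving disabled
    if value.lower() == "all":
        return True           # save outputs for every component
    save_outputs = {}         # assumed initialization
    for submodel_id in value.split(","):
        save_outputs[submodel_id] = True
    return save_outputs

print(parse_save_outputs("unet,vae"))  # {'unet': True, 'vae': True}
```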
models/turbine_models/custom_models/sd_inference/sd_pipeline.py (72 changes: 54 additions & 18 deletions)

@@ -190,6 +190,7 @@ def get_sd_model_map(hf_model_name):
"stabilityai/sdxl-turbo",
"stabilityai/stable-diffusion-xl-base-1.0",
"/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16/checkpoint_pipe",
"/models/SDXL/official_pytorch/fp16/stable_diffusion_fp16//checkpoint_pipe",
]:
return sdxl_model_map
elif "stabilityai/stable-diffusion-3" in name:
@@ -233,6 +234,12 @@ def __init__(
         benchmark: bool | dict[bool] = False,
         verbose: bool = False,
         batch_prompts: bool = False,
+        punet_quant_paths: dict[str] = None,
+        vae_weight_path: str = None,
+        vae_harness: bool = True,
+        add_tk_kernels: bool = False,
+        tk_kernels_dir: str | dict[str] = None,
+        save_outputs: bool | dict[bool] = False,
     ):
         common_export_args = {
             "hf_model_name": None,
@@ -243,11 +250,11 @@ def __init__(
"exit_on_vmfb": False,
"pipeline_dir": pipeline_dir,
"input_mlir": None,
"attn_spec": None,
"attn_spec": attn_spec,
"external_weights": None,
"external_weight_path": None,
}
sd_model_map = get_sd_model_map(hf_model_name)
sd_model_map = copy.deepcopy(get_sd_model_map(hf_model_name))
for submodel in sd_model_map:
if "load" not in sd_model_map[submodel]:
sd_model_map[submodel]["load"] = True
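Why the deepcopy matters, as a self-contained sketch (the map literal is a stand-in, not the real model map): get_sd_model_map returns a module-level dict, so mutating it in place would leak keywords and export args into every pipeline constructed later in the same process.

```python
import copy

sdxl_model_map = {"unet": {"keywords": ["fp16"]}}  # stand-in for the real map

def get_sd_model_map():
    return sdxl_model_map

aliased = get_sd_model_map()                  # old behavior: an alias
aliased["unet"]["keywords"].append("punet")
print(sdxl_model_map["unet"]["keywords"])     # ['fp16', 'punet'] - state leaked

isolated = copy.deepcopy(get_sd_model_map())  # new behavior: isolated copy
isolated["unet"]["keywords"].append("i8")
print(sdxl_model_map["unet"]["keywords"])     # ['fp16', 'punet'] - unaffected
```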
@@ -281,6 +288,7 @@ def __init__(
             hf_model_name,
             benchmark,
             verbose,
+            save_outputs,
             common_export_args,
         )
         for submodel in sd_model_map:
@@ -303,6 +311,7 @@ def __init__(
         self.cpu_scheduling = cpu_scheduling
         self.scheduler_id = scheduler_id
         self.num_inference_steps = num_inference_steps
+        self.punet_quant_paths = punet_quant_paths
 
         self.text_encoder = None
         self.unet = None
@@ -311,6 +320,8 @@ def __init__(
         self.scheduler = None
 
         self.split_scheduler = True
+        self.add_tk_kernels = add_tk_kernels
+        self.tk_kernels_dir = tk_kernels_dir
 
         self.base_model_name = (
             hf_model_name
@@ -339,6 +350,9 @@ def __init__(
             self.scheduler_device = self.map["unet"]["device"]
             self.scheduler_driver = self.map["unet"]["driver"]
             self.scheduler_target = self.map["unet"]["target"]
+            if vae_weight_path is not None:
+                self.map["vae"]["export_args"]["external_weight_path"] = vae_weight_path
+                self.map["vae"]["export_args"]["vae_harness"] = vae_harness
         elif not self.is_sd3:
             self.tokenizer = CLIPTokenizer.from_pretrained(
                 self.base_model_name, subfolder="tokenizer"
@@ -351,23 +365,31 @@ def __init__(
 
         self.latents_dtype = torch_dtypes[self.latents_precision]
         self.use_i8_punet = self.use_punet = use_i8_punet
+        if self.use_punet:
+            self.setup_punet()
+        else:
+            self.map["unet"]["keywords"].append("!punet")
+            self.map["unet"]["function_name"] = "run_forward"
+
+    def setup_punet(self):
         if self.use_i8_punet:
+            if self.add_tk_kernels:
+                self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels
+                self.map["unet"]["export_args"]["tk_kernels_dir"] = self.tk_kernels_dir
             self.map["unet"]["export_args"]["precision"] = "i8"
-            self.map["unet"]["export_args"]["use_punet"] = True
-            self.map["unet"]["use_weights_for_export"] = True
-            self.map["unet"]["keywords"].append("punet")
-            self.map["unet"]["module_name"] = "compiled_punet"
-            self.map["unet"]["function_name"] = "main"
             self.map["unet"]["export_args"]["external_weight_path"] = (
                 utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa"
             )
+            self.map["unet"]["export_args"]["quant_paths"] = self.punet_quant_paths
             for idx, word in enumerate(self.map["unet"]["keywords"]):
                 if word in ["fp32", "fp16"]:
                     self.map["unet"]["keywords"][idx] = "i8"
                     break
-        else:
-            self.map["unet"]["keywords"].append("!punet")
-            self.map["unet"]["function_name"] = "run_forward"
+        self.map["unet"]["export_args"]["use_punet"] = True
+        self.map["unet"]["use_weights_for_export"] = True
+        self.map["unet"]["keywords"].append("punet")
+        self.map["unet"]["module_name"] = "compiled_punet"
+        self.map["unet"]["function_name"] = "main"
 
     # LOAD
 
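For orientation, a sketch of what the "unet" map entry looks like after setup_punet() runs with use_i8_punet=True and add_tk_kernels=False. The dict is illustrative only: the weight filename stands in for whatever utils.create_safe_name produces, and quant_paths reflects the constructor argument.

```python
# Hypothetical post-setup_punet state of self.map["unet"] (i8 path, no tk kernels).
unet_entry = {
    "module_name": "compiled_punet",
    "function_name": "main",
    "use_weights_for_export": True,
    "keywords": ["i8", "punet"],  # the fp16/fp32 keyword is swapped for "i8"
    "export_args": {
        "precision": "i8",
        "use_punet": True,
        "quant_paths": None,  # the punet_quant_paths passed to the constructor
        "external_weight_path": "stable_diffusion_xl_base_1_0_punet_dataset_i8.irpa",
    },
}
```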
@@ -376,10 +398,6 @@ def load_scheduler(
         scheduler_id: str,
         steps: int = 30,
     ):
-        if self.is_sd3:
-            scheduler_device = self.mmdit.device
-        else:
-            scheduler_device = self.unet.device
         if not self.cpu_scheduling:
             self.map["scheduler"] = {
                 "module_name": "compiled_scheduler",
@@ -426,6 +444,10 @@ def load_scheduler(
                 print("JIT export of scheduler failed. Loading CPU scheduler.")
                 self.cpu_scheduling = True
         if self.cpu_scheduling:
+            if self.is_sd3:
+                scheduler_device = self.mmdit.device
+            else:
+                scheduler_device = self.unet.device
             scheduler = schedulers.get_scheduler(self.base_model_name, scheduler_id)
             self.scheduler = schedulers.SharkSchedulerCPUWrapper(
                 scheduler,
@@ -481,9 +503,12 @@ def prepare_latents(
         elif self.is_sdxl and self.cpu_scheduling:
             self.scheduler.do_guidance = False
             self.scheduler.repeat_sample = False
-            sample, add_time_ids, step_indexes, timesteps = (
-                self.scheduler.initialize_sdxl(noise, num_inference_steps)
-            )
+            (
+                sample,
+                add_time_ids,
+                step_indexes,
+                timesteps,
+            ) = self.scheduler.initialize_sdxl(noise, num_inference_steps)
             return sample, add_time_ids, step_indexes, timesteps
         elif self.is_sdxl:
             return self.scheduler("run_initialize", noise)
@@ -565,9 +590,11 @@ def _produce_latents_sdxl(
             [guidance_scale],
             dtype=self.map["unet"]["np_dtype"],
         )
+        # Disable progress bar if we aren't in verbose mode or if we're printing
+        # benchmark latencies for unet.
         for i, t in tqdm(
             enumerate(timesteps),
-            disable=(self.map["unet"].get("benchmark") and self.verbose),
+            disable=(self.map["unet"].get("benchmark") or not self.verbose),
         ):
             if self.cpu_scheduling:
                 latent_model_input, t = self.scheduler.scale_model_input(
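A quick check of the new disable predicate (illustrative only): the progress bar now appears only in verbose, non-benchmark runs, so benchmark latency prints are never interleaved with tqdm output.

```python
# Truth table for the old vs. new tqdm disable conditions.
for benchmark in (False, True):
    for verbose in (False, True):
        old = benchmark and verbose     # condition before this PR
        new = benchmark or not verbose  # condition after this PR
        print(f"benchmark={benchmark} verbose={verbose} "
              f"old_disable={old} new_disable={new}")
```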
@@ -720,6 +747,14 @@ def numpy_to_pil_image(images):
                 benchmark[i] = True
     else:
         benchmark = False
+    if args.save_outputs:
+        if args.save_outputs.lower() == "all":
+            save_outputs = True
+        else:
+            save_outputs = {}
+            for i in args.save_outputs.split(","):
+                save_outputs[i] = True
+    else:
+        save_outputs = False
     if any(x for x in [args.vae_decomp_attn, args.unet_decomp_attn]):
         args.decomp_attn = {
             "text_encoder": args.decomp_attn,
@@ -750,6 +785,7 @@ def numpy_to_pil_image(images):
         args.use_i8_punet,
         benchmark,
         args.verbose,
+        save_outputs=save_outputs,
     )
     sd_pipe.prepare_all()
     sd_pipe.load_map()