Skip to content

Commit

Permalink
Modify the annotation OTF to return bytecode module (huggingface#980)
Browse files · Browse the repository at this point in the history
  • Loading branch information
yzhang93 authored Feb 8, 2023
1 parent 83c69ec commit e9864cb
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
23 changes: 17 additions & 6 deletions apps/stable_diffusion/src/utils/sd_annotation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import io
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
Expand Down Expand Up @@ -97,10 +98,15 @@ def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
search_op="conv",
winograd=True,
)
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return winograd_model, out_file_path

bytecode_stream = io.BytesIO()
winograd_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()

with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return bytecode, out_file_path


def dump_after_mlir(input_mlir, model_name, use_winograd):
Expand Down Expand Up @@ -176,10 +182,15 @@ def annotate_with_lower_configs(
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"

bytecode_stream = io.BytesIO()
tuned_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()

with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return tuned_model, out_file_path
return bytecode, out_file_path


def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
Expand Down Expand Up @@ -215,7 +226,7 @@ def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
mlir_model, lowering_config_dir, model_name, use_winograd
)
print(f"Saved the annotated mlir in {output_path}.")
return tuned_model, output_path
return tuned_model


if __name__ == "__main__":
Expand Down
27 changes: 8 additions & 19 deletions apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,26 +96,19 @@ def compile_through_fx(
)

if use_tuned:
tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
if not os.path.exists(tuned_model_path):
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"

tuned_model, tuned_model_path = sd_model_annotation(
mlir_module, model_name
)
del mlir_module, tuned_model
gc.collect()

with open(tuned_model_path, "rb") as f:
mlir_module = f.read()
f.close()
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
mlir_module = sd_model_annotation(mlir_module, model_name)

shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)

del mlir_module
gc.collect()

return _compile_module(shark_module, model_name, extra_args)


Expand Down Expand Up @@ -253,11 +246,7 @@ def set_init_device_flags():
):
args.use_tuned = False

elif "cuda" in args.device and get_cuda_sm_cc() not in [
"sm_80",
"sm_84",
"sm_86",
]:
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80"]:
args.use_tuned = False

elif args.use_base_vae and args.hf_model_id not in [
Expand Down

0 comments on commit e9864cb

Please sign in to comment.