diff --git a/apps/stable_diffusion/src/utils/sd_annotation.py b/apps/stable_diffusion/src/utils/sd_annotation.py
index 89c2dc3d2911b..95bef4dc23f26 100644
--- a/apps/stable_diffusion/src/utils/sd_annotation.py
+++ b/apps/stable_diffusion/src/utils/sd_annotation.py
@@ -1,4 +1,5 @@
 import os
+import io
 from shark.model_annotation import model_annotation, create_context
 from shark.iree_utils._common import iree_target_map, run_cmd
 from shark.shark_downloader import (
@@ -97,10 +98,15 @@ def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
         search_op="conv",
         winograd=True,
     )
-    with open(out_file_path, "w") as f:
-        f.write(str(winograd_model))
-        f.close()
-    return winograd_model, out_file_path
+
+    bytecode_stream = io.BytesIO()
+    winograd_model.operation.write_bytecode(bytecode_stream)
+    bytecode = bytecode_stream.getvalue()
+
+    with open(out_file_path, "w") as f:
+        f.write(str(winograd_model))
+        f.close()
+    return bytecode, out_file_path
 
 
 def dump_after_mlir(input_mlir, model_name, use_winograd):
@@ -176,10 +182,15 @@ def annotate_with_lower_configs(
         )
     else:
         out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
+
+    bytecode_stream = io.BytesIO()
+    tuned_model.operation.write_bytecode(bytecode_stream)
+    bytecode = bytecode_stream.getvalue()
+
     with open(out_file_path, "w") as f:
         f.write(str(tuned_model))
         f.close()
-    return tuned_model, out_file_path
+    return bytecode, out_file_path
 
 
 def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
@@ -215,7 +226,7 @@ def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
         mlir_model, lowering_config_dir, model_name, use_winograd
     )
     print(f"Saved the annotated mlir in {output_path}.")
-    return tuned_model, output_path
+    return tuned_model
 
 
 if __name__ == "__main__":
diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py
index 6acfa823fc8ac..42f2f5b372520 100644
--- a/apps/stable_diffusion/src/utils/utils.py
+++ b/apps/stable_diffusion/src/utils/utils.py
@@ -96,26 +96,19 @@ def compile_through_fx(
     )
 
     if use_tuned:
-        tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
-        if not os.path.exists(tuned_model_path):
-            if "vae" in model_name.split("_")[0]:
-                args.annotation_model = "vae"
-
-            tuned_model, tuned_model_path = sd_model_annotation(
-                mlir_module, model_name
-            )
-            del mlir_module, tuned_model
-            gc.collect()
-
-        with open(tuned_model_path, "rb") as f:
-            mlir_module = f.read()
-            f.close()
+        if "vae" in model_name.split("_")[0]:
+            args.annotation_model = "vae"
+        mlir_module = sd_model_annotation(mlir_module, model_name)
 
     shark_module = SharkInference(
         mlir_module,
         device=args.device,
         mlir_dialect="linalg",
     )
+
+    del mlir_module
+    gc.collect()
+
     return _compile_module(shark_module, model_name, extra_args)
 
 
@@ -253,11 +246,7 @@ def set_init_device_flags():
     ):
         args.use_tuned = False
 
-    elif "cuda" in args.device and get_cuda_sm_cc() not in [
-        "sm_80",
-        "sm_84",
-        "sm_86",
-    ]:
+    elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80"]:
         args.use_tuned = False
 
     elif args.use_base_vae and args.hf_model_id not in [
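
The core of this change swaps the old text round trip (writing `str(module)` to disk, then re-reading it with `open(..., "rb")`) for in-memory bytecode serialization via the MLIR Python bindings' `Operation.write_bytecode`. A minimal sketch of the pattern, assuming the upstream `mlir.ir` bindings; the inline IR and the `test.op` dialect below are illustrative only and not from this PR:

```python
# Minimal sketch: serialize an MLIR module to bytecode in memory, assuming
# the upstream MLIR Python bindings (`mlir.ir`). Illustrative only.
import io

from mlir.ir import Context, Module

with Context() as ctx:
    # Allow an unregistered "test" dialect so the sketch parses without any
    # dialect registration; real modules would use registered dialects.
    ctx.allow_unregistered_dialects = True
    module = Module.parse('module { "test.op"() : () -> () }')

    # Same pattern as annotate_with_winograd / annotate_with_lower_configs:
    # write bytecode into a BytesIO buffer and hand the bytes downstream
    # (in the PR, to SharkInference) instead of re-reading a .mlir text file.
    bytecode_stream = io.BytesIO()
    module.operation.write_bytecode(bytecode_stream)
    bytecode = bytecode_stream.getvalue()

print(f"serialized {len(bytecode)} bytes of MLIR bytecode")
```

Returning the bytecode avoids re-parsing textual IR and the disk round trip; the annotated `.mlir` text file is still written in both annotate functions, seemingly only as an inspectable artifact.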