diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py
index 60225ff6..2456ae4b 100644
--- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py
+++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py
@@ -237,6 +237,7 @@ def __init__(
         punet_quant_paths: dict[str] = None,
         vae_weight_path: str = None,
         vae_harness: bool = False,
+        add_tk_kernels: bool = False,
     ):
         common_export_args = {
             "hf_model_name": None,
@@ -316,6 +317,7 @@ def __init__(
         self.scheduler = None
 
         self.split_scheduler = True
+        self.add_tk_kernels = add_tk_kernels
 
         self.base_model_name = (
             hf_model_name
@@ -367,6 +369,8 @@ def __init__(
 
     def setup_punet(self):
         if self.use_i8_punet:
+            if self.add_tk_kernels:
+                self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels
             self.map["unet"]["export_args"]["precision"] = "i8"
             self.map["unet"]["export_args"]["external_weight_path"] = (
                 utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa"
diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py
index 0300c790..9d5c149a 100644
--- a/models/turbine_models/custom_models/sd_inference/utils.py
+++ b/models/turbine_models/custom_models/sd_inference/utils.py
@@ -155,12 +155,15 @@ def iree_backend_map(device):
     return iree_device
 
 
-def replace_with_tk_kernels(
-    flow_dialect_ir,
-):
-    kernels = [
-        "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/tk_int8/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir"
-    ]
+def replace_with_tk_kernels(flow_dialect_ir, batch_size):
+    if batch_size == 8:
+        kernels = [
+            "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_16x1024x10240x1280.mlir"
+        ]
+    if batch_size == 1:
+        kernels = [
+            "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir"
+        ]
 
     # Replace all calls to old kernel with new kernel
     print("Inserting kernels and updating calls to kernels...")
@@ -235,7 +238,10 @@ def compile_to_vmfb(
     flagset_keywords=[],
     debug=False,
     add_tk_kernels=False,
+    batch_size=1,
 ):
+    if batch_size != 1 and batch_size != 8:
+        add_tk_kernels = False
     flags = []
     if mlir_source == "file" and not isinstance(module_str, str):
         module_str = str(module_str)
@@ -393,7 +399,7 @@ def compile_to_vmfb(
 
         flow_ir = flatbuffer_blob.decode("utf-8")
 
-        flow_ir_tk = replace_with_tk_kernels(flow_ir)
+        flow_ir_tk = replace_with_tk_kernels(flow_ir, batch_size)
         module_str = "\n".join(flow_ir_tk)
         flags.pop()
         flags.extend(["--compile-from=flow"])
diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py
index acccc391..4ed874e2 100644
--- a/models/turbine_models/custom_models/sdxl_inference/unet.py
+++ b/models/turbine_models/custom_models/sdxl_inference/unet.py
@@ -370,6 +370,7 @@ class CompiledUnet(CompiledModule):
         attn_spec=attn_spec,
         flagset_keywords=["punet"] if use_punet else [],
         add_tk_kernels=add_tk_kernels,
+        batch_size=batch_size,
     )
     if exit_on_vmfb:
         exit()
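
Note for reviewers (not part of the patch): the snippet below is a minimal, standalone sketch of the batch-size gating and kernel selection that the utils.py changes above implement. The helper name select_tk_kernels and the TK_KERNEL_URLS mapping are illustrative only; the actual patch inlines this logic in replace_with_tk_kernels and compile_to_vmfb.

# Illustrative sketch only -- mirrors the logic added in utils.py above.
# TK kernels are provided for batch sizes 1 and 8; any other size falls back
# to the default compilation path (add_tk_kernels is forced to False).
TK_KERNEL_URLS = {
    8: "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_16x1024x10240x1280.mlir",
    1: "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir",
}


def select_tk_kernels(batch_size: int, add_tk_kernels: bool = True) -> list:
    # Return the kernel URLs to splice into the flow IR, or an empty list
    # when tk kernels are disabled or unavailable for this batch size.
    if not add_tk_kernels or batch_size not in TK_KERNEL_URLS:
        return []
    return [TK_KERNEL_URLS[batch_size]]


if __name__ == "__main__":
    print(select_tk_kernels(1))  # single fused GEMM kernel URL for batch size 1
    print(select_tk_kernels(4))  # [] -> tk kernels skipped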