From 6f3f701d3d6a4c1386ae770a3bbe997a14bbdef6 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Thu, 18 Jan 2024 18:23:21 +0200 Subject: [PATCH] Deduplicate ipex initialization code --- XTI_hijack.py | 10 +++------- fine_tune.py | 9 ++------- gen_img_diffusers.py | 9 ++------- library/ipex_interop.py | 24 ++++++++++++++++++++++++ library/model_util.py | 10 ++-------- sdxl_gen_img.py | 9 ++------- sdxl_minimal_inference.py | 12 +++++------- sdxl_train.py | 9 ++------- sdxl_train_control_net_lllite.py | 12 +++++------- sdxl_train_control_net_lllite_old.py | 12 +++++------- sdxl_train_network.py | 9 ++------- sdxl_train_textual_inversion.py | 10 +++------- train_controlnet.py | 9 ++------- train_db.py | 9 ++------- train_network.py | 9 ++------- train_textual_inversion.py | 9 ++------- train_textual_inversion_XTI.py | 12 +++++------- 17 files changed, 70 insertions(+), 113 deletions(-) create mode 100644 library/ipex_interop.py diff --git a/XTI_hijack.py b/XTI_hijack.py index ec0849455..1dbc263ac 100644 --- a/XTI_hijack.py +++ b/XTI_hijack.py @@ -1,11 +1,7 @@ import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.ipex_interop import init_ipex + +init_ipex() from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput diff --git a/fine_tune.py b/fine_tune.py index be61b3d16..982dc8aec 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -11,15 +11,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index be43847a6..a207ad5a1 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -66,15 +66,10 @@ import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/library/ipex_interop.py b/library/ipex_interop.py new file mode 100644 index 000000000..6fe320c57 --- /dev/null +++ b/library/ipex_interop.py @@ -0,0 +1,24 @@ +import torch + + +def init_ipex(): + """ + Try to import `intel_extension_for_pytorch`, and apply + the hijacks using `library.ipex.ipex_init`. + + If IPEX is not installed, this function does nothing. + """ + try: + import intel_extension_for_pytorch as ipex # noqa + except ImportError: + return + + try: + from library.ipex import ipex_init + + if torch.xpu.is_available(): + is_initialized, error_message = ipex_init() + if not is_initialized: + print("failed to initialize ipex:", error_message) + except Exception as e: + print("failed to initialize ipex:", e) diff --git a/library/model_util.py b/library/model_util.py index 1f40ce324..4361b4994 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -5,15 +5,9 @@ import os import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - - ipex_init() -except Exception: - pass +init_ipex() import diffusers from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index ab5399842..0db9e340e 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -18,15 +18,10 @@ import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass import torchvision from diffusers import ( AutoencoderKL, diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 45b9edd65..15a70678f 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -9,13 +9,11 @@ from einops import repeat import numpy as np import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from tqdm import tqdm from transformers import CLIPTokenizer from diffusers import EulerDiscreteScheduler diff --git a/sdxl_train.py b/sdxl_train.py index b4ce2770e..a3f6f3a17 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -11,15 +11,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import sdxl_model_util diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 4436dd3cd..7a88feb84 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -14,13 +14,11 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed import accelerate diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 6ae5377ba..b94bf5c1b 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -11,13 +11,11 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/sdxl_train_network.py b/sdxl_train_network.py index d810ce7d4..5d363280d 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,15 +1,10 @@ import argparse import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from library import sdxl_model_util, sdxl_train_util, train_util import train_network diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index f8a1d7bce..df3937135 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -3,13 +3,9 @@ import regex import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass +from library.ipex_interop import init_ipex + +init_ipex() import open_clip from library import sdxl_model_util, sdxl_train_util, train_util diff --git a/train_controlnet.py b/train_controlnet.py index cc0eaab7a..7b0b2bbfe 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -12,15 +12,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from torch.nn.parallel import DistributedDataParallel as DDP from accelerate.utils import set_seed from diffusers import DDPMScheduler, ControlNetModel diff --git a/train_db.py b/train_db.py index 14d9dff13..888cad25e 100644 --- a/train_db.py +++ b/train_db.py @@ -12,15 +12,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler diff --git a/train_network.py b/train_network.py index c2b7fbdef..e1ff20c33 100644 --- a/train_network.py +++ b/train_network.py @@ -14,15 +14,10 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from library import model_util diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 0e3912b1d..ccf7596d7 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -8,15 +8,10 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex +from library.ipex_interop import init_ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init +init_ipex() - ipex_init() -except Exception: - pass from accelerate.utils import set_seed from diffusers import DDPMScheduler from transformers import CLIPTokenizer diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 71b43549d..7046a4808 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -8,13 +8,11 @@ from tqdm import tqdm import torch -try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - from library.ipex import ipex_init - ipex_init() -except Exception: - pass + +from library.ipex_interop import init_ipex + +init_ipex() + from accelerate.utils import set_seed import diffusers from diffusers import DDPMScheduler