Skip to content

Commit

Permalink
Merge pull request #1060 from akx/refactor-xpu-init
Browse files Browse the repository at this point in the history
Deduplicate ipex initialization code
  • Loading branch information
kohya-ss authored Jan 23, 2024
2 parents 6805caf + 6f3f701 commit bea4362
Show file tree
Hide file tree
Showing 17 changed files with 70 additions and 113 deletions.
10 changes: 3 additions & 7 deletions XTI_hijack.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex

init_ipex()
from typing import Union, List, Optional, Dict, Any, Tuple
from diffusers.models.unet_2d_condition import UNet2DConditionOutput

Expand Down
9 changes: 2 additions & 7 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,10 @@
from tqdm import tqdm
import torch

try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex

if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()

ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

Expand Down
9 changes: 2 additions & 7 deletions gen_img_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,10 @@
import numpy as np
import torch

try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex

if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()

ipex_init()
except Exception:
pass
import torchvision
from diffusers import (
AutoencoderKL,
Expand Down
24 changes: 24 additions & 0 deletions library/ipex_interop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch


def init_ipex():
"""
Try to import `intel_extension_for_pytorch`, and apply
the hijacks using `library.ipex.ipex_init`.
If IPEX is not installed, this function does nothing.
"""
try:
import intel_extension_for_pytorch as ipex # noqa
except ImportError:
return

try:
from library.ipex import ipex_init

if torch.xpu.is_available():
is_initialized, error_message = ipex_init()
if not is_initialized:
print("failed to initialize ipex:", error_message)
except Exception as e:
print("failed to initialize ipex:", e)
10 changes: 2 additions & 8 deletions library/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,9 @@
import os
import torch

try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex

if torch.xpu.is_available():
from library.ipex import ipex_init

ipex_init()
except Exception:
pass
init_ipex()
import diffusers
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline # , UNet2DConditionModel
Expand Down
9 changes: 2 additions & 7 deletions sdxl_gen_img.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,10 @@
import numpy as np
import torch

try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex

if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()

ipex_init()
except Exception:
pass
import torchvision
from diffusers import (
AutoencoderKL,
Expand Down
12 changes: 5 additions & 7 deletions sdxl_minimal_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
from einops import repeat
import numpy as np
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass

from library.ipex_interop import init_ipex

init_ipex()

from tqdm import tqdm
from transformers import CLIPTokenizer
from diffusers import EulerDiscreteScheduler
Expand Down
9 changes: 2 additions & 7 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,10 @@
from tqdm import tqdm
import torch

try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex

if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()

ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import sdxl_model_util
Expand Down
12 changes: 5 additions & 7 deletions sdxl_train_control_net_lllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@

from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass

from library.ipex_interop import init_ipex

init_ipex()

from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
import accelerate
Expand Down
12 changes: 5 additions & 7 deletions sdxl_train_control_net_lllite_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@

from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass

from library.ipex_interop import init_ipex

init_ipex()

from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
Expand Down
9 changes: 2 additions & 7 deletions sdxl_train_network.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
import argparse
import torch

try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex

if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()

ipex_init()
except Exception:
pass
from library import sdxl_model_util, sdxl_train_util, train_util
import train_network

Expand Down
10 changes: 3 additions & 7 deletions sdxl_train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@

import regex
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
from library.ipex_interop import init_ipex

init_ipex()
import open_clip
from library import sdxl_model_util, sdxl_train_util, train_util

Expand Down
9 changes: 2 additions & 7 deletions train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,10 @@
from tqdm import tqdm
import torch

try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex

if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()

ipex_init()
except Exception:
pass
from torch.nn.parallel import DistributedDataParallel as DDP
from accelerate.utils import set_seed
from diffusers import DDPMScheduler, ControlNetModel
Expand Down
9 changes: 2 additions & 7 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,10 @@
from tqdm import tqdm
import torch

try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex

if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()

ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

Expand Down
9 changes: 2 additions & 7 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,10 @@
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex

if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()

ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import model_util
Expand Down
9 changes: 2 additions & 7 deletions train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,10 @@
from tqdm import tqdm
import torch

try:
import intel_extension_for_pytorch as ipex
from library.ipex_interop import init_ipex

if torch.xpu.is_available():
from library.ipex import ipex_init
init_ipex()

ipex_init()
except Exception:
pass
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
Expand Down
12 changes: 5 additions & 7 deletions train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@

from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass

from library.ipex_interop import init_ipex

init_ipex()

from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
Expand Down

0 comments on commit bea4362

Please sign in to comment.