From 9ef0293f744eb35e4e411b5d4675db35e85699ea Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sun, 8 Dec 2024 20:50:20 -0500 Subject: [PATCH 1/5] Lint unused import --- .pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index a5da56e57ca..27f9324329f 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,3 +1,3 @@ [MESSAGES CONTROL] disable=all -enable=eval-used +enable=eval-used unused-import From b1be8ac71c468bd853c1083c02440dac75d9a0b6 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sun, 8 Dec 2024 20:51:20 -0500 Subject: [PATCH 2/5] nit --- .pylintrc | 2 +- styles/default.csv | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 styles/default.csv diff --git a/.pylintrc b/.pylintrc index 27f9324329f..fc0eb2259d6 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,3 +1,3 @@ [MESSAGES CONTROL] disable=all -enable=eval-used unused-import +enable=eval-used, unused-import diff --git a/styles/default.csv b/styles/default.csv new file mode 100644 index 00000000000..b7977e59745 --- /dev/null +++ b/styles/default.csv @@ -0,0 +1,3 @@ +name,prompt,negative_prompt +❌Low Token,,"embedding:EasyNegative, NSFW, Cleavage, Pubic Hair, Nudity, Naked, censored" +✅Line Art / Manga,"(Anime Scene, Toonshading, Satoshi Kon, Ken Sugimori, Hiromu Arakawa:1.2), (Anime Style, Manga Style:1.3), Low detail, sketch, concept art, line art, webtoon, manhua, hand drawn, defined lines, simple shades, minimalistic, High contrast, Linear compositions, Scalable artwork, Digital art, High Contrast Shadows, glow effects, humorous illustration, big depth of field, Masterpiece, colors, concept art, trending on artstation, Vivid colors, dramatic", From ee13a28b412a223f88860240ec221bf6d02fafea Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sun, 8 Dec 2024 20:55:24 -0500 Subject: [PATCH 3/5] Remove unused imports --- comfy/cldm/cldm.py | 2 -- comfy/cldm/dit_embedder.py | 2 -- comfy/cldm/mmdit.py | 2 +- comfy/extra_samplers/uni_pc.py | 3 +-- comfy/ldm/audio/autoencoder.py | 2 +- comfy/ldm/audio/embedders.py | 4 ++-- comfy/ldm/cascade/controlnet.py | 1 - comfy/ldm/flux/controlnet.py | 4 +--- comfy/ldm/genmo/joint_model/utils.py | 2 +- comfy/ldm/genmo/vae/model.py | 2 +- comfy/ldm/hydit/controlnet.py | 7 ------- comfy/ldm/hydit/models.py | 2 -- comfy/ldm/hydit/poolers.py | 1 - comfy/ldm/lightricks/vae/causal_video_autoencoder.py | 2 +- comfy/ldm/lightricks/vae/conv_nd_factory.py | 1 - comfy/ldm/models/autoencoder.py | 2 +- comfy/ldm/modules/diffusionmodules/mmdit.py | 2 -- comfy/ldm/modules/diffusionmodules/model.py | 1 - comfy/ldm/modules/diffusionmodules/openaimodel.py | 1 - comfy/ldm/modules/diffusionmodules/upscaling.py | 1 - comfy/ldm/modules/diffusionmodules/util.py | 1 - comfy/ldm/modules/temporal_ae.py | 2 +- comfy/sampler_helpers.py | 1 - comfy/text_encoders/spiece_tokenizer.py | 1 - comfy_extras/nodes_advanced_samplers.py | 3 +-- comfy_extras/nodes_clip_sdxl.py | 1 - comfy_extras/nodes_compositing.py | 1 - comfy_extras/nodes_hooks.py | 1 - comfy_extras/nodes_model_advanced.py | 1 - comfy_extras/nodes_model_downscale.py | 1 - comfy_extras/nodes_upscale_model.py | 1 - custom_nodes/websocket_image_save.py | 4 +--- execution.py | 1 - folder_paths.py | 4 ++-- latent_preview.py | 2 -- main.py | 2 +- script_examples/basic_api_example.py | 3 +-- styles/default.csv | 3 --- tests/inference/test_inference.py | 1 - 39 files changed, 17 insertions(+), 61 deletions(-) delete mode 100644 styles/default.csv diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 9ec64a22751..05282a3be33 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -2,11 +2,9 @@ #and modified import torch -import torch as th import torch.nn as nn from ..ldm.modules.diffusionmodules.util import ( - zero_module, timestep_embedding, ) diff --git a/comfy/cldm/dit_embedder.py b/comfy/cldm/dit_embedder.py index e9cdd49910b..f9bf31012b1 100644 --- a/comfy/cldm/dit_embedder.py +++ b/comfy/cldm/dit_embedder.py @@ -1,10 +1,8 @@ import math from typing import List, Optional, Tuple -import numpy as np import torch import torch.nn as nn -from einops import rearrange from torch import Tensor from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch diff --git a/comfy/cldm/mmdit.py b/comfy/cldm/mmdit.py index 54a58ab835a..b7764085e94 100644 --- a/comfy/cldm/mmdit.py +++ b/comfy/cldm/mmdit.py @@ -1,5 +1,5 @@ import torch -from typing import Dict, Optional +from typing import Optional import comfy.ldm.modules.diffusionmodules.mmdit class ControlNet(comfy.ldm.modules.diffusionmodules.mmdit.MMDiT): diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 3ab42c6a940..39365752083 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -1,10 +1,9 @@ #code taken from: https://github.com/wl-zhao/UniPC and modified import torch -import torch.nn.functional as F import math -from tqdm.auto import trange, tqdm +from tqdm.auto import trange class NoiseScheduleVP: diff --git a/comfy/ldm/audio/autoencoder.py b/comfy/ldm/audio/autoencoder.py index 8123e66a500..21044d17f8f 100644 --- a/comfy/ldm/audio/autoencoder.py +++ b/comfy/ldm/audio/autoencoder.py @@ -2,7 +2,7 @@ import torch from torch import nn -from typing import Literal, Dict, Any +from typing import Literal import math import comfy.ops ops = comfy.ops.disable_weight_init diff --git a/comfy/ldm/audio/embedders.py b/comfy/ldm/audio/embedders.py index 82a3210c60d..20edb365aaa 100644 --- a/comfy/ldm/audio/embedders.py +++ b/comfy/ldm/audio/embedders.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn -from torch import Tensor, einsum -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from torch import Tensor +from typing import List, Union from einops import rearrange import math import comfy.ops diff --git a/comfy/ldm/cascade/controlnet.py b/comfy/ldm/cascade/controlnet.py index 7a52c3c263f..90473481a07 100644 --- a/comfy/ldm/cascade/controlnet.py +++ b/comfy/ldm/cascade/controlnet.py @@ -16,7 +16,6 @@ along with this program. If not, see . """ -import torch import torchvision from torch import nn from .common import LayerNorm2d_op diff --git a/comfy/ldm/flux/controlnet.py b/comfy/ldm/flux/controlnet.py index c033dea52f2..5322c489101 100644 --- a/comfy/ldm/flux/controlnet.py +++ b/comfy/ldm/flux/controlnet.py @@ -6,9 +6,7 @@ from torch import Tensor, nn from einops import rearrange, repeat -from .layers import (DoubleStreamBlock, EmbedND, LastLayer, - MLPEmbedder, SingleStreamBlock, - timestep_embedding) +from .layers import (timestep_embedding) from .model import Flux import comfy.ldm.common_dit diff --git a/comfy/ldm/genmo/joint_model/utils.py b/comfy/ldm/genmo/joint_model/utils.py index 411902423b4..1b399d5d212 100644 --- a/comfy/ldm/genmo/joint_model/utils.py +++ b/comfy/ldm/genmo/joint_model/utils.py @@ -1,7 +1,7 @@ #original code from https://github.com/genmoai/models under apache 2.0 license #adapted to ComfyUI -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn as nn diff --git a/comfy/ldm/genmo/vae/model.py b/comfy/ldm/genmo/vae/model.py index b68d48ae5d7..1bde0c1ed73 100644 --- a/comfy/ldm/genmo/vae/model.py +++ b/comfy/ldm/genmo/vae/model.py @@ -1,7 +1,7 @@ #original code from https://github.com/genmoai/models under apache 2.0 license #adapted to ComfyUI -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union from functools import partial import math diff --git a/comfy/ldm/hydit/controlnet.py b/comfy/ldm/hydit/controlnet.py index cd71fca31aa..e1fb45294a6 100644 --- a/comfy/ldm/hydit/controlnet.py +++ b/comfy/ldm/hydit/controlnet.py @@ -1,24 +1,17 @@ -from typing import Any, Optional import torch import torch.nn as nn -import torch.nn.functional as F -from torch.utils import checkpoint from comfy.ldm.modules.diffusionmodules.mmdit import ( - Mlp, TimestepEmbedder, PatchEmbed, - RMSNorm, ) -from comfy.ldm.modules.diffusionmodules.util import timestep_embedding from .poolers import AttentionPool import comfy.latent_formats from .models import HunYuanDiTBlock, calc_rope -from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop class HunYuanControlNet(nn.Module): diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py index 88459457d10..4de60795f09 100644 --- a/comfy/ldm/hydit/models.py +++ b/comfy/ldm/hydit/models.py @@ -1,8 +1,6 @@ -from typing import Any import torch import torch.nn as nn -import torch.nn.functional as F import comfy.ops from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm diff --git a/comfy/ldm/hydit/poolers.py b/comfy/ldm/hydit/poolers.py index f5e5b406fcd..c1b878ed6b0 100644 --- a/comfy/ldm/hydit/poolers.py +++ b/comfy/ldm/hydit/poolers.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import torch.nn.functional as F from comfy.ldm.modules.attention import optimized_attention import comfy.ops diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 33b2c2d4f18..3bd59a76ea1 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -3,7 +3,7 @@ from functools import partial import math from einops import rearrange -from typing import Any, Mapping, Optional, Tuple, Union, List +from typing import Optional, Tuple, Union from .conv_nd_factory import make_conv_nd, make_linear_nd from .pixel_norm import PixelNorm diff --git a/comfy/ldm/lightricks/vae/conv_nd_factory.py b/comfy/ldm/lightricks/vae/conv_nd_factory.py index c5f067bf09e..52df4ee22e3 100644 --- a/comfy/ldm/lightricks/vae/conv_nd_factory.py +++ b/comfy/ldm/lightricks/vae/conv_nd_factory.py @@ -1,6 +1,5 @@ from typing import Tuple, Union -import torch from .dual_conv3d import DualConv3d from .causal_conv3d import CausalConv3d diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index f5f4de28830..3eeff24e23a 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -1,6 +1,6 @@ import torch from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Tuple, Union from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 6f8f506ce02..7365503f51d 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -1,5 +1,3 @@ -import logging -import math from typing import Dict, Optional, List import numpy as np diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 04eb83b2181..a60ca307b55 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn import numpy as np -from typing import Optional, Any import logging from comfy import model_management diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 3f7fee708ff..4c8d53cac9c 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -9,7 +9,6 @@ from .util import ( checkpoint, avg_pool_nd, - zero_module, timestep_embedding, AlphaBlender, ) diff --git a/comfy/ldm/modules/diffusionmodules/upscaling.py b/comfy/ldm/modules/diffusionmodules/upscaling.py index f5ac7c2f913..9dbf1fe7b93 100644 --- a/comfy/ldm/modules/diffusionmodules/upscaling.py +++ b/comfy/ldm/modules/diffusionmodules/upscaling.py @@ -4,7 +4,6 @@ from functools import partial from .util import extract_into_tensor, make_beta_schedule -from comfy.ldm.util import default class AbstractLowScaleModel(nn.Module): diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index ce14ad5e18c..9377b0737fb 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -8,7 +8,6 @@ # thanks! -import os import math import torch import torch.nn as nn diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py index 2992aeafc35..000cbe8ff95 100644 --- a/comfy/ldm/modules/temporal_ae.py +++ b/comfy/ldm/modules/temporal_ae.py @@ -1,5 +1,5 @@ import functools -from typing import Callable, Iterable, Union +from typing import Iterable, Union import torch from einops import rearrange, repeat diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 1252d8a5bf6..1924a8c5510 100644 --- a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -1,6 +1,5 @@ from __future__ import annotations import uuid -import torch import comfy.model_management import comfy.conds import comfy.utils diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py index 73739553d47..cbaa99ba5b5 100644 --- a/comfy/text_encoders/spiece_tokenizer.py +++ b/comfy/text_encoders/spiece_tokenizer.py @@ -1,4 +1,3 @@ -import os import torch class SPieceTokenizer: diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py index 820c250ef3a..5fbb096fbf8 100644 --- a/comfy_extras/nodes_advanced_samplers.py +++ b/comfy_extras/nodes_advanced_samplers.py @@ -2,8 +2,7 @@ import comfy.utils import torch import numpy as np -from tqdm.auto import trange, tqdm -import math +from tqdm.auto import trange @torch.no_grad() diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py index b8e241578e7..70764b230f6 100644 --- a/comfy_extras/nodes_clip_sdxl.py +++ b/comfy_extras/nodes_clip_sdxl.py @@ -1,4 +1,3 @@ -import torch from nodes import MAX_RESOLUTION class CLIPTextEncodeSDXLRefiner: diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index 48fe5e3ddc6..2f994fa11d3 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -1,4 +1,3 @@ -import numpy as np import torch import comfy.utils from enum import Enum diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index f73a0e9b0f7..d0cb6990206 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -4,7 +4,6 @@ from collections.abc import Iterable if TYPE_CHECKING: - from comfy.model_patcher import ModelPatcher from comfy.sd import CLIP import comfy.hooks diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index f085bf12fe5..e57d1d56fa7 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -1,4 +1,3 @@ -import folder_paths import comfy.sd import comfy.model_sampling import comfy.latent_formats diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index 15ffc4c8ee6..49420dee926 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -1,4 +1,3 @@ -import torch import comfy.utils class PatchModelAddDownscale: diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 6ba3e404f2e..04c94834129 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -1,4 +1,3 @@ -import os import logging from spandrel import ModelLoader, ImageModelDescriptor from comfy import model_management diff --git a/custom_nodes/websocket_image_save.py b/custom_nodes/websocket_image_save.py index 09fe1bde5f6..15f87f9f561 100644 --- a/custom_nodes/websocket_image_save.py +++ b/custom_nodes/websocket_image_save.py @@ -1,7 +1,5 @@ -from PIL import Image, ImageOps -from io import BytesIO +from PIL import Image import numpy as np -import struct import comfy.utils import time diff --git a/execution.py b/execution.py index 929ef85fac4..8fff8983956 100644 --- a/execution.py +++ b/execution.py @@ -17,7 +17,6 @@ from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID from comfy_execution.validation import validate_node_input -from comfy.cli_args import args class ExecutionResult(Enum): SUCCESS = 0 diff --git a/folder_paths.py b/folder_paths.py index 0c9e9f15dae..577a7bc649e 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -4,7 +4,7 @@ import time import mimetypes import logging -from typing import Set, List, Dict, Tuple, Literal +from typing import Literal from collections.abc import Collection supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} @@ -133,7 +133,7 @@ def get_directory_by_type(type_name: str) -> str | None: return get_input_directory() return None -def filter_files_content_types(files: List[str], content_types: Literal["image", "video", "audio"]) -> List[str]: +def filter_files_content_types(files: list[str], content_types: Literal["image", "video", "audio"]) -> list[str]: """ Example: files = os.listdir(folder_paths.get_input_directory()) diff --git a/latent_preview.py b/latent_preview.py index d60e68d5512..07f9cc68e97 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -1,7 +1,5 @@ import torch from PIL import Image -import struct -import numpy as np from comfy.cli_args import args, LatentPreviewMethod from comfy.taesd.taesd import TAESD import comfy.model_management diff --git a/main.py b/main.py index 5535622d6f1..3a2206834cd 100644 --- a/main.py +++ b/main.py @@ -87,7 +87,7 @@ def execute_script(script_path): if args.windows_standalone_build: try: - import fix_torch + pass except: pass diff --git a/script_examples/basic_api_example.py b/script_examples/basic_api_example.py index bc8ad713410..c916e6cb989 100644 --- a/script_examples/basic_api_example.py +++ b/script_examples/basic_api_example.py @@ -1,6 +1,5 @@ import json -from urllib import request, parse -import random +from urllib import request #This is the ComfyUI api prompt format. diff --git a/styles/default.csv b/styles/default.csv deleted file mode 100644 index b7977e59745..00000000000 --- a/styles/default.csv +++ /dev/null @@ -1,3 +0,0 @@ -name,prompt,negative_prompt -❌Low Token,,"embedding:EasyNegative, NSFW, Cleavage, Pubic Hair, Nudity, Naked, censored" -✅Line Art / Manga,"(Anime Scene, Toonshading, Satoshi Kon, Ken Sugimori, Hiromu Arakawa:1.2), (Anime Style, Manga Style:1.3), Low detail, sketch, concept art, line art, webtoon, manhua, hand drawn, defined lines, simple shades, minimalistic, High contrast, Linear compositions, Scalable artwork, Digital art, High Contrast Shadows, glow effects, humorous illustration, big depth of field, Masterpiece, colors, concept art, trending on artstation, Vivid colors, dramatic", diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 2e11778f22c..1db3c06fb0c 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -1,6 +1,5 @@ from copy import deepcopy from io import BytesIO -from urllib import request import numpy import os from PIL import Image From 10e08b05544652700a9ca38a710e236fffec2abf Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sun, 8 Dec 2024 20:58:28 -0500 Subject: [PATCH 4/5] revert fix_torch import --- main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 3a2206834cd..ceaa9d809c4 100644 --- a/main.py +++ b/main.py @@ -86,8 +86,9 @@ def execute_script(script_path): import cuda_malloc if args.windows_standalone_build: + # TODO: Convert fix_torch to a function. try: - pass + import fix_torch # noqa: F401 except: pass From 73b26e53754ffc9842941ce6b763668de5653829 Mon Sep 17 00:00:00 2001 From: huchenlei Date: Sun, 8 Dec 2024 21:01:41 -0500 Subject: [PATCH 5/5] nit --- fix_torch.py | 36 ++++++++++++++++++++---------------- main.py | 4 ++-- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/fix_torch.py b/fix_torch.py index e350f5c7d50..4aecb23f09a 100644 --- a/fix_torch.py +++ b/fix_torch.py @@ -5,20 +5,24 @@ import logging -torch_spec = importlib.util.find_spec("torch") -for folder in torch_spec.submodule_search_locations: - lib_folder = os.path.join(folder, "lib") - test_file = os.path.join(lib_folder, "fbgemm.dll") - dest = os.path.join(lib_folder, "libomp140.x86_64.dll") - if os.path.exists(dest): - break - - with open(test_file, 'rb') as f: - contents = f.read() - if b"libomp140.x86_64.dll" not in contents: +def fix_pytorch_libomp(): + """ + Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed. + """ + torch_spec = importlib.util.find_spec("torch") + for folder in torch_spec.submodule_search_locations: + lib_folder = os.path.join(folder, "lib") + test_file = os.path.join(lib_folder, "fbgemm.dll") + dest = os.path.join(lib_folder, "libomp140.x86_64.dll") + if os.path.exists(dest): break - try: - mydll = ctypes.cdll.LoadLibrary(test_file) - except FileNotFoundError as e: - logging.warning("Detected pytorch version with libomp issue, patching.") - shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest) + + with open(test_file, "rb") as f: + contents = f.read() + if b"libomp140.x86_64.dll" not in contents: + break + try: + mydll = ctypes.cdll.LoadLibrary(test_file) + except FileNotFoundError as e: + logging.warning("Detected pytorch version with libomp issue, patching.") + shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest) diff --git a/main.py b/main.py index ceaa9d809c4..87ce1526b2b 100644 --- a/main.py +++ b/main.py @@ -86,9 +86,9 @@ def execute_script(script_path): import cuda_malloc if args.windows_standalone_build: - # TODO: Convert fix_torch to a function. try: - import fix_torch # noqa: F401 + from fix_torch import fix_pytorch_libomp + fix_pytorch_libomp() except: pass