Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize model library #5841

Merged
merged 10 commits into from
Dec 11, 2024
1 change: 1 addition & 0 deletions model_filemanager/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# model_manager/__init__.py
from .download_models import download_model, DownloadModelStatus, DownloadStatusType, create_model_path, check_file_exists, track_download_progress, validate_filename
from .model_filemanager import ModelFileManager
179 changes: 179 additions & 0 deletions model_filemanager/model_filemanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import os
import time
import logging
import folder_paths
import glob
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types


class ModelFileManager:
def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}

def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default)

def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
self.cache[key] = value

def clear_cache(self):
self.cache.clear()

def add_routes(self, routes):
@routes.get("/models")
def list_model_types(request):
model_types = list(folder_paths.folder_names_and_paths.keys())

return web.json_response(model_types)

@routes.get("/models/{folder}")
async def get_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = folder_paths.get_filename_list(folder)
return web.json_response(files)

# NOTE: This is an experiment to replace `/models`
@routes.get("/experiment/models")
async def get_model_folders(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
folder_black_list = ["configs", "custom_nodes"]
output_folders: list[dict] = []
for folder in model_types:
if folder in folder_black_list:
continue
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
return web.json_response(output_folders)

# NOTE: This is an experiment to replace `/models/{folder}`
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)

@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)

if not folder_name in folder_paths.folder_names_and_paths:
return web.Response(status=404)

folders = folder_paths.folder_names_and_paths[folder_name]
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)

preview_files = self.get_model_previews(full_filename)
default_preview_file = preview_files[0] if len(preview_files) > 0 else None
if default_preview_file is None or not os.path.isfile(default_preview_file):
return web.Response(status=404)

try:
with Image.open(default_preview_file) as img:
img_bytes = BytesIO()
img.save(img_bytes, format="WEBP")
img_bytes.seek(0)
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
except:
return web.Response(status=404)

def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
output_list: list[dict] = []

for index, folder in enumerate(folders[0]):
if not os.path.isdir(folder):
continue
out = self.cache_model_file_list_(folder)
if out is None:
out = self.recursive_search_models_(folder, index)
self.set_cache(folder, out)
output_list.extend(out[0])

return output_list

def cache_model_file_list_(self, folder: str):
model_file_list_cache = self.get_cache(folder)

if model_file_list_cache is None:
return None
if not os.path.isdir(folder):
return None
if os.path.getmtime(folder) != model_file_list_cache[1]:
return None
for x in model_file_list_cache[1]:
time_modified = model_file_list_cache[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None

return model_file_list_cache

def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
if not os.path.isdir(directory):
return [], {}, time.perf_counter()

excluded_dir_names = [".git"]
# TODO use settings
include_hidden_files = False

result: list[str] = []
dirs: dict[str, float] = {}

for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
if not include_hidden_files:
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
filenames = [f for f in filenames if not f.startswith(".")]

filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)

for file_name in filenames:
try:
relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
result.append(relative_path)
except:
logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
continue

for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue

return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()

def get_model_previews(self, filepath: str) -> list[str]:
dirname = os.path.dirname(filepath)

if not os.path.exists(dirname):
return []

basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")

result: list[str] = []

for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
return result

def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()
18 changes: 3 additions & 15 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import node_helpers
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from model_filemanager import download_model, DownloadModelStatus
from model_filemanager import download_model, DownloadModelStatus, ModelFileManager
from typing import Optional
from api_server.routes.internal.internal_routes import InternalRoutes

Expand Down Expand Up @@ -152,6 +152,7 @@ def __init__(self, loop):
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'

self.user_manager = UserManager()
self.model_file_manager = ModelFileManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
Expand Down Expand Up @@ -221,20 +222,6 @@ async def get_root(request):
def get_embeddings(self):
embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))

@routes.get("/models")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to keep these routes for now if we want to merge this PR first. If we decide to land this PR after frontend changes has landed, the frontend side needs to be able to handle both APIs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two APIs have not been removed, just moved to model_filemanager/model_filemanager.py. I can restore it if needed.

Copy link
Contributor Author

@hayden-fr hayden-fr Nov 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new API adds the prefix of /experiment. My plan is to retain the original API and then migrate it to the original API after the new function is stable.

def list_model_types(request):
model_types = list(folder_paths.folder_names_and_paths.keys())

return web.json_response(model_types)

@routes.get("/models/{folder}")
async def get_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = folder_paths.get_filename_list(folder)
return web.json_response(files)

@routes.get("/extensions")
async def get_extensions(request):
Expand Down Expand Up @@ -713,6 +700,7 @@ async def setup(self):

def add_routes(self):
self.user_manager.add_routes(self.routes)
self.model_file_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app())

# Prefix every route with /api for easier matching for delegation.
Expand Down