Skip to content

Commit

Permalink
feat: merged from archive/diffus-20241012
Browse files Browse the repository at this point in the history
  • Loading branch information
btian committed Dec 10, 2024
1 parent 9a616b8 commit 5722402
Show file tree
Hide file tree
Showing 44 changed files with 3,963 additions and 555 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ venv/
*.log
web_custom_versions/
.DS_Store

custom_nodes.bak*
40 changes: 23 additions & 17 deletions api_server/routes/internal/internal_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths

import execution_context
from folder_paths import get_models_dir, get_user_directory, get_output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
from api_server.services.terminal_service import TerminalService
import app.logger
Expand All @@ -16,9 +18,9 @@ def __init__(self, prompt_server):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
"models": models_dir,
"user": user_directory,
"output": output_directory
# "models": get_models_dir,
"user": get_user_directory,
"output": get_output_directory
})
self.prompt_server = prompt_server
self.terminal_service = TerminalService(prompt_server)
Expand All @@ -27,39 +29,43 @@ def setup_routes(self):
@self.routes.get('/files')
async def list_files(request):
directory_key = request.query.get('directory', '')
context = execution_context.ExecutionContext(request)
try:
file_list = self.file_service.list_files(directory_key)
file_list = self.file_service.list_files(context, directory_key)
return web.json_response({"files": file_list})
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)

@self.routes.get('/logs')
async def get_logs(request):
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
# return web.json_response(app.logger.get_logs())
return web.json_response([])

@self.routes.get('/logs/raw')
async def get_logs(request):
self.terminal_service.update_size()
# return web.json_response({
# "entries": list(app.logger.get_logs()),
# "size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
# })
return web.json_response({
"entries": list(app.logger.get_logs()),
"size": {"cols": self.terminal_service.cols, "rows": self.terminal_service.rows}
"entries": [],
"size": {"cols": 0, "rows": 0}
})

@self.routes.patch('/logs/subscribe')
async def subscribe_logs(request):
json_data = await request.json()
client_id = json_data["clientId"]
enabled = json_data["enabled"]
if enabled:
self.terminal_service.subscribe(client_id)
else:
self.terminal_service.unsubscribe(client_id)
# json_data = await request.json()
# client_id = json_data["clientId"]
# enabled = json_data["enabled"]
# if enabled:
# self.terminal_service.subscribe(client_id)
# else:
# self.terminal_service.unsubscribe(client_id)

return web.Response(status=200)


@self.routes.get('/folder_paths')
async def get_folder_paths(request):
response = {}
Expand Down
12 changes: 7 additions & 5 deletions api_server/services/file_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Callable

import execution_context
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem

class FileService:
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
self.allowed_directories: Dict[str, str] = allowed_directories
def __init__(self, allowed_directories: Dict[str, Callable], file_system_ops: Optional[FileSystemOperations] = None):
self.allowed_directories: Dict[str, Callable] = allowed_directories
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()

def list_files(self, directory_key: str) -> List[FileSystemItem]:
def list_files(self, context: execution_context.ExecutionContext, directory_key: str) -> List[FileSystemItem]:
if directory_key not in self.allowed_directories:
raise ValueError("Invalid directory key")
directory_path: str = self.allowed_directories[directory_key]
directory_path: str = self.allowed_directories[directory_key](context.user_hash)
return self.file_system_ops.walk_directory(directory_path)
10 changes: 7 additions & 3 deletions app/app_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ def __init__(self, user_manager):
self.user_manager = user_manager

def get_settings(self, request):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
# use user private settings
file = self.user_manager.get_request_user_filepath(request, "comfy.settings.json")
if not os.path.isfile(file):
# use default user settings
file = self.user_manager.get_default_user_filepath("comfy.settings.json")

if os.path.isfile(file):
with open(file) as f:
return json.load(f)
Expand Down Expand Up @@ -51,4 +55,4 @@ async def post_setting(request):
settings = self.get_settings(request)
settings[setting_id] = await request.json()
self.save_settings(request, settings)
return web.Response(status=200)
return web.Response(status=200)
154 changes: 69 additions & 85 deletions app/user_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import logging
from aiohttp import web
from urllib import parse

import execution_context
from comfy.cli_args import args
import folder_paths
from .app_settings import AppSettings
Expand All @@ -32,26 +34,27 @@ def get_file_info(path: str, relative_to: str) -> FileInfo:

class UserManager():
def __init__(self):
user_directory = folder_paths.get_user_directory()
# user_directory = folder_paths.get_user_directory('')

self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.makedirs(user_directory, exist_ok=True)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

if args.multi_user:
if os.path.isfile(self.get_users_file()):
with open(self.get_users_file()) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}

def get_users_file(self):
return os.path.join(folder_paths.get_user_directory(), "users.json")
# if not os.path.exists(user_directory):
# os.mkdir(user_directory)
# if not args.multi_user:
# print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
# print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

# if args.multi_user:
# if os.path.isfile(self.get_users_file()):
# with open(self.get_users_file()) as f:
# self.users = json.load(f)
# else:
# self.users = {}
# else:
# self.users = {"default": "default"}
self.users = {"default": "default"}

def get_users_file(self, context: execution_context.ExecutionContext):
return os.path.join(folder_paths.get_user_directory(context.user_hash), "users.json")

def get_request_user_id(self, request):
user = "default"
Expand All @@ -63,16 +66,25 @@ def get_request_user_id(self, request):

return user

def get_default_user_filepath(self, file, type="userdata"):
return self._get_request_user_filepath("default", file, type, False)

def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
user_directory = folder_paths.get_user_directory()
context = execution_context.ExecutionContext(request)
if not context.user_hash:
raise Exception("user hash is not provided")
return self._get_request_user_filepath(context.user_hash, file, type, create_dir)

def _get_request_user_filepath(self, user_hash, file, type="userdata", create_dir=True):
user_directory = folder_paths.get_user_directory(user_hash)

if type == "userdata":
root_dir = user_directory
else:
raise KeyError("Unknown filepath type:" + type)

user = self.get_request_user_id(request)
path = user_root = os.path.abspath(os.path.join(root_dir, user))
# user = self.get_request_user_id(request)
path = user_root = os.path.abspath(root_dir)

# prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir:
Expand All @@ -95,7 +107,7 @@ def get_request_user_filepath(self, request, file, type="userdata", create_dir=T

return path

def add_user(self, name):
def add_user(self, context: execution_context.ExecutionContext, name):
name = name.strip()
if not name:
raise ValueError("username not provided")
Expand All @@ -104,7 +116,7 @@ def add_user(self, name):

self.users[user_id] = name

with open(self.get_users_file(), "w") as f:
with open(self.get_users_file(context), "w") as f:
json.dump(self.users, f)

return user_id
Expand All @@ -125,13 +137,14 @@ async def get_users(request):

@routes.post("/users")
async def post_users(request):
body = await request.json()
username = body["username"]
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)

user_id = self.add_user(username)
return web.json_response(user_id)
# body = await request.json()
# username = body["username"]
# if username in self.users.values():
# return web.json_response({"error": "Duplicate username."}, status=400)
#
# user_id = self.add_user(username)
# return web.json_response(user_id)
return web.Response(status=403, text="Forbidden")

@routes.get("/userdata")
async def listuserdata(request):
Expand Down Expand Up @@ -270,61 +283,32 @@ async def post_userdata(request):

@routes.delete("/userdata/{file}")
async def delete_userdata(request):
path = get_user_data_path(request, check_exists=True)
if not isinstance(path, str):
return path

os.remove(path)

return web.Response(status=204)
# path = get_user_data_path(request, check_exists=True)
# if not isinstance(path, str):
# return path
#
# os.remove(path)
#
# return web.Response(status=204)
return web.Response(status=403, text="Forbidden")

@routes.post("/userdata/{file}/move/{dest}")
async def move_userdata(request):
"""
Move or rename a user data file.
This endpoint handles moving or renaming files within a user's data directory, with options for
controlling overwrite behavior and response format.
Path Parameters:
- file: The source file path (URL encoded if necessary)
- dest: The destination file path (URL encoded if necessary)
Query Parameters:
- overwrite (optional): If "false", prevents overwriting existing files. Defaults to "true".
- full_info (optional): If "true", returns detailed file information (path, size, modified time).
If "false", returns only the relative file path.
Returns:
- 400: If either 'file' or 'dest' parameter is missing
- 403: If either requested path is not allowed
- 404: If the source file does not exist
- 409: If overwrite=false and the destination file already exists
- 200: JSON response with either:
- Full file information (if full_info=true)
- Relative file path (if full_info=false)
"""
source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str):
return source

dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str):
return dest

overwrite = request.query.get("overwrite", 'true') != "false"
full_info = request.query.get('full_info', 'false').lower() == "true"

if not overwrite and os.path.exists(dest):
return web.Response(status=409, text="File already exists")

logging.info(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)

user_path = self.get_request_user_filepath(request, None)
if full_info:
resp = get_file_info(dest, user_path)
else:
resp = os.path.relpath(dest, user_path)

return web.json_response(resp)
# source = get_user_data_path(request, check_exists=True)
# if not isinstance(source, str):
# return source
#
# dest = get_user_data_path(request, check_exists=False, param="dest")
# if not isinstance(source, str):
# return dest
#
# overwrite = request.query["overwrite"] != "false"
# if not overwrite and os.path.exists(dest):
# return web.Response(status=409)
#
# print(f"moving '{source}' -> '{dest}'")
# shutil.move(source, dest)
#
# resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
# return web.json_response(resp)
return web.Response(status=403, text="Forbidden")
8 changes: 7 additions & 1 deletion comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
if hasattr(ckpt, 'is_safetensors') and hasattr(ckpt, 'filename'):
is_safetensors = ckpt.is_safetensors
ckpt = ckpt.filename
else:
is_safetensors = ckpt.lower().endswith(".safetensors")

if is_safetensors:
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if safe_load:
Expand Down
Loading

0 comments on commit 5722402

Please sign in to comment.