Skip to content

Commit

Permalink
diffus multi-user support
Browse files Browse the repository at this point in the history
  • Loading branch information
btian committed Oct 28, 2024
1 parent 6632365 commit ed5b430
Show file tree
Hide file tree
Showing 42 changed files with 3,899 additions and 497 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ venv/
*.log
web_custom_versions/
.DS_Store

custom_nodes.bak*
18 changes: 10 additions & 8 deletions api_server/routes/internal/internal_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths

import execution_context
from folder_paths import get_models_dir, get_user_directory, get_output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
import app.logger

Expand All @@ -15,34 +17,34 @@ def __init__(self):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
"models": models_dir,
"user": user_directory,
"output": output_directory
# "models": get_models_dir,
"user": get_user_directory,
"output": get_output_directory
})

def setup_routes(self):
@self.routes.get('/files')
async def list_files(request):
directory_key = request.query.get('directory', '')
context = execution_context.ExecutionContext(request)
try:
file_list = self.file_service.list_files(directory_key)
file_list = self.file_service.list_files(context, directory_key)
return web.json_response({"files": file_list})
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)

@self.routes.get('/logs')
async def get_logs(request):
return web.json_response(app.logger.get_logs())
# return web.json_response(app.logger.get_logs())
return web.json_response([])

@self.routes.get('/folder_paths')
async def get_folder_paths(request):
response = {}
for key in folder_names_and_paths:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)

def get_app(self):
if self._app is None:
self._app = web.Application()
Expand Down
12 changes: 7 additions & 5 deletions api_server/services/file_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Callable

import execution_context
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem

class FileService:
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
self.allowed_directories: Dict[str, str] = allowed_directories
def __init__(self, allowed_directories: Dict[str, Callable], file_system_ops: Optional[FileSystemOperations] = None):
self.allowed_directories: Dict[str, Callable] = allowed_directories
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()

def list_files(self, directory_key: str) -> List[FileSystemItem]:
def list_files(self, context: execution_context.ExecutionContext, directory_key: str) -> List[FileSystemItem]:
if directory_key not in self.allowed_directories:
raise ValueError("Invalid directory key")
directory_path: str = self.allowed_directories[directory_key]
directory_path: str = self.allowed_directories[directory_key](context.user_hash)
return self.file_system_ops.walk_directory(directory_path)
10 changes: 7 additions & 3 deletions app/app_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ def __init__(self, user_manager):
self.user_manager = user_manager

def get_settings(self, request):
file = self.user_manager.get_request_user_filepath(
request, "comfy.settings.json")
# use user private settings
file = self.user_manager.get_request_user_filepath(request, "comfy.settings.json")
if not os.path.isfile(file):
# use default user settings
file = self.user_manager.get_default_user_filepath("comfy.settings.json")

if os.path.isfile(file):
with open(file) as f:
return json.load(f)
Expand Down Expand Up @@ -51,4 +55,4 @@ async def post_setting(request):
settings = self.get_settings(request)
settings[setting_id] = await request.json()
self.save_settings(request, settings)
return web.Response(status=200)
return web.Response(status=200)
123 changes: 69 additions & 54 deletions app/user_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import shutil
from aiohttp import web
from urllib import parse

import execution_context
from comfy.cli_args import args
import folder_paths
from .app_settings import AppSettings
Expand All @@ -15,26 +17,27 @@

class UserManager():
def __init__(self):
user_directory = folder_paths.get_user_directory()
# user_directory = folder_paths.get_user_directory('')

self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

if args.multi_user:
if os.path.isfile(self.get_users_file()):
with open(self.get_users_file()) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}

def get_users_file(self):
return os.path.join(folder_paths.get_user_directory(), "users.json")
# if not os.path.exists(user_directory):
# os.mkdir(user_directory)
# if not args.multi_user:
# print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
# print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

# if args.multi_user:
# if os.path.isfile(self.get_users_file()):
# with open(self.get_users_file()) as f:
# self.users = json.load(f)
# else:
# self.users = {}
# else:
# self.users = {"default": "default"}
self.users = {"default": "default"}

def get_users_file(self, context: execution_context.ExecutionContext):
return os.path.join(folder_paths.get_user_directory(context.user_hash), "users.json")

def get_request_user_id(self, request):
user = "default"
Expand All @@ -46,16 +49,25 @@ def get_request_user_id(self, request):

return user

def get_default_user_filepath(self, file, type="userdata"):
return self._get_request_user_filepath("default", file, type, False)

def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
user_directory = folder_paths.get_user_directory()
context = execution_context.ExecutionContext(request)
if not context.user_hash:
raise Exception("user hash is not provided")
return self._get_request_user_filepath(context.user_hash, file, type, create_dir)

def _get_request_user_filepath(self, user_hash, file, type="userdata", create_dir=True):
user_directory = folder_paths.get_user_directory(user_hash)

if type == "userdata":
root_dir = user_directory
else:
raise KeyError("Unknown filepath type:" + type)

user = self.get_request_user_id(request)
path = user_root = os.path.abspath(os.path.join(root_dir, user))
# user = self.get_request_user_id(request)
path = user_root = os.path.abspath(root_dir)

# prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir:
Expand All @@ -78,7 +90,7 @@ def get_request_user_filepath(self, request, file, type="userdata", create_dir=T

return path

def add_user(self, name):
def add_user(self, context: execution_context.ExecutionContext, name):
name = name.strip()
if not name:
raise ValueError("username not provided")
Expand All @@ -87,7 +99,7 @@ def add_user(self, name):

self.users[user_id] = name

with open(self.get_users_file(), "w") as f:
with open(self.get_users_file(context), "w") as f:
json.dump(self.users, f)

return user_id
Expand All @@ -108,13 +120,14 @@ async def get_users(request):

@routes.post("/users")
async def post_users(request):
body = await request.json()
username = body["username"]
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)

user_id = self.add_user(username)
return web.json_response(user_id)
# body = await request.json()
# username = body["username"]
# if username in self.users.values():
# return web.json_response({"error": "Duplicate username."}, status=400)
#
# user_id = self.add_user(username)
# return web.json_response(user_id)
return web.Response(status=403, text="Forbidden")

@routes.get("/userdata")
async def listuserdata(request):
Expand Down Expand Up @@ -226,30 +239,32 @@ async def post_userdata(request):

@routes.delete("/userdata/{file}")
async def delete_userdata(request):
path = get_user_data_path(request, check_exists=True)
if not isinstance(path, str):
return path

os.remove(path)

return web.Response(status=204)
# path = get_user_data_path(request, check_exists=True)
# if not isinstance(path, str):
# return path
#
# os.remove(path)
#
# return web.Response(status=204)
return web.Response(status=403, text="Forbidden")

@routes.post("/userdata/{file}/move/{dest}")
async def move_userdata(request):
source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str):
return source

dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str):
return dest

overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest):
return web.Response(status=409)

print(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)

resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
return web.json_response(resp)
# source = get_user_data_path(request, check_exists=True)
# if not isinstance(source, str):
# return source
#
# dest = get_user_data_path(request, check_exists=False, param="dest")
# if not isinstance(source, str):
# return dest
#
# overwrite = request.query["overwrite"] != "false"
# if not overwrite and os.path.exists(dest):
# return web.Response(status=409)
#
# print(f"moving '{source}' -> '{dest}'")
# shutil.move(source, dest)
#
# resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
# return web.json_response(resp)
return web.Response(status=403, text="Forbidden")
8 changes: 7 additions & 1 deletion comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
if hasattr(ckpt, 'is_safetensors') and hasattr(ckpt, 'filename'):
is_safetensors = ckpt.is_safetensors
ckpt = ckpt.filename
else:
is_safetensors = ckpt.lower().endswith(".safetensors")

if is_safetensors:
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if safe_load:
Expand Down
Loading

0 comments on commit ed5b430

Please sign in to comment.