Skip to content

Commit

Permalink
diffus multi-user support
Browse files Browse the repository at this point in the history
  • Loading branch information
btian committed Oct 17, 2024
1 parent 6632365 commit 0e97e90
Show file tree
Hide file tree
Showing 42 changed files with 3,863 additions and 504 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@ venv/
*.log
web_custom_versions/
.DS_Store

custom_nodes.bak*
18 changes: 10 additions & 8 deletions api_server/routes/internal/internal_routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths

import execution_context
from folder_paths import get_models_dir, get_user_directory, get_output_directory, folder_names_and_paths
from api_server.services.file_service import FileService
import app.logger

Expand All @@ -15,34 +17,34 @@ def __init__(self):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
self.file_service = FileService({
"models": models_dir,
"user": user_directory,
"output": output_directory
# "models": get_models_dir,
"user": get_user_directory,
"output": get_output_directory
})

def setup_routes(self):
@self.routes.get('/files')
async def list_files(request):
directory_key = request.query.get('directory', '')
context = execution_context.ExecutionContext(request)
try:
file_list = self.file_service.list_files(directory_key)
file_list = self.file_service.list_files(context, directory_key)
return web.json_response({"files": file_list})
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
except Exception as e:
return web.json_response({"error": str(e)}, status=500)

@self.routes.get('/logs')
async def get_logs(request):
return web.json_response(app.logger.get_logs())
# return web.json_response(app.logger.get_logs())
return web.json_response([])

@self.routes.get('/folder_paths')
async def get_folder_paths(request):
response = {}
for key in folder_names_and_paths:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)

def get_app(self):
if self._app is None:
self._app = web.Application()
Expand Down
12 changes: 7 additions & 5 deletions api_server/services/file_service.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Callable

import execution_context
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem

class FileService:
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
self.allowed_directories: Dict[str, str] = allowed_directories
def __init__(self, allowed_directories: Dict[str, Callable], file_system_ops: Optional[FileSystemOperations] = None):
self.allowed_directories: Dict[str, Callable] = allowed_directories
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()

def list_files(self, directory_key: str) -> List[FileSystemItem]:
def list_files(self, context: execution_context.ExecutionContext, directory_key: str) -> List[FileSystemItem]:
if directory_key not in self.allowed_directories:
raise ValueError("Invalid directory key")
directory_path: str = self.allowed_directories[directory_key]
directory_path: str = self.allowed_directories[directory_key](context.user_hash)
return self.file_system_ops.walk_directory(directory_path)
24 changes: 13 additions & 11 deletions app/app_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,19 @@ async def get_setting(request):

@routes.post("/settings")
async def post_settings(request):
settings = self.get_settings(request)
new_settings = await request.json()
self.save_settings(request, {**settings, **new_settings})
return web.Response(status=200)
# settings = self.get_settings(request)
# new_settings = await request.json()
# self.save_settings(request, {**settings, **new_settings})
# return web.Response(status=200)
return web.Response(status=403)

@routes.post("/settings/{id}")
async def post_setting(request):
setting_id = request.match_info.get("id", None)
if not setting_id:
return web.Response(status=400)
settings = self.get_settings(request)
settings[setting_id] = await request.json()
self.save_settings(request, settings)
return web.Response(status=200)
# setting_id = request.match_info.get("id", None)
# if not setting_id:
# return web.Response(status=400)
# settings = self.get_settings(request)
# settings[setting_id] = await request.json()
# self.save_settings(request, settings)
# return web.Response(status=200)
return web.Response(status=403)
114 changes: 61 additions & 53 deletions app/user_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import shutil
from aiohttp import web
from urllib import parse

import execution_context
from comfy.cli_args import args
import folder_paths
from .app_settings import AppSettings
Expand All @@ -15,26 +17,27 @@

class UserManager():
def __init__(self):
user_directory = folder_paths.get_user_directory()
# user_directory = folder_paths.get_user_directory('')

self.settings = AppSettings(self)
if not os.path.exists(user_directory):
os.mkdir(user_directory)
if not args.multi_user:
print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

if args.multi_user:
if os.path.isfile(self.get_users_file()):
with open(self.get_users_file()) as f:
self.users = json.load(f)
else:
self.users = {}
else:
self.users = {"default": "default"}

def get_users_file(self):
return os.path.join(folder_paths.get_user_directory(), "users.json")
# if not os.path.exists(user_directory):
# os.mkdir(user_directory)
# if not args.multi_user:
# print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
# print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

# if args.multi_user:
# if os.path.isfile(self.get_users_file()):
# with open(self.get_users_file()) as f:
# self.users = json.load(f)
# else:
# self.users = {}
# else:
# self.users = {"default": "default"}
self.users = {"default": "default"}

def get_users_file(self, context: execution_context.ExecutionContext):
return os.path.join(folder_paths.get_user_directory(context.user_hash), "users.json")

def get_request_user_id(self, request):
user = "default"
Expand All @@ -47,14 +50,16 @@ def get_request_user_id(self, request):
return user

def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
user_directory = folder_paths.get_user_directory()
context = execution_context.ExecutionContext(request)
user_directory = folder_paths.get_user_directory(context.user_hash)

if type == "userdata":
root_dir = user_directory
else:
raise KeyError("Unknown filepath type:" + type)

user = self.get_request_user_id(request)
# user = self.get_request_user_id(request)
user = context.user_id
path = user_root = os.path.abspath(os.path.join(root_dir, user))

# prevent leaving /{type}
Expand All @@ -78,7 +83,7 @@ def get_request_user_filepath(self, request, file, type="userdata", create_dir=T

return path

def add_user(self, name):
def add_user(self, context: execution_context.ExecutionContext, name):
name = name.strip()
if not name:
raise ValueError("username not provided")
Expand All @@ -87,7 +92,7 @@ def add_user(self, name):

self.users[user_id] = name

with open(self.get_users_file(), "w") as f:
with open(self.get_users_file(context), "w") as f:
json.dump(self.users, f)

return user_id
Expand All @@ -108,13 +113,14 @@ async def get_users(request):

@routes.post("/users")
async def post_users(request):
body = await request.json()
username = body["username"]
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)

user_id = self.add_user(username)
return web.json_response(user_id)
# body = await request.json()
# username = body["username"]
# if username in self.users.values():
# return web.json_response({"error": "Duplicate username."}, status=400)
#
# user_id = self.add_user(username)
# return web.json_response(user_id)
return web.Response(status=403, text="Forbidden")

@routes.get("/userdata")
async def listuserdata(request):
Expand Down Expand Up @@ -226,30 +232,32 @@ async def post_userdata(request):

@routes.delete("/userdata/{file}")
async def delete_userdata(request):
path = get_user_data_path(request, check_exists=True)
if not isinstance(path, str):
return path

os.remove(path)

return web.Response(status=204)
# path = get_user_data_path(request, check_exists=True)
# if not isinstance(path, str):
# return path
#
# os.remove(path)
#
# return web.Response(status=204)
return web.Response(status=403, text="Forbidden")

@routes.post("/userdata/{file}/move/{dest}")
async def move_userdata(request):
source = get_user_data_path(request, check_exists=True)
if not isinstance(source, str):
return source

dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str):
return dest

overwrite = request.query["overwrite"] != "false"
if not overwrite and os.path.exists(dest):
return web.Response(status=409)

print(f"moving '{source}' -> '{dest}'")
shutil.move(source, dest)

resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
return web.json_response(resp)
# source = get_user_data_path(request, check_exists=True)
# if not isinstance(source, str):
# return source
#
# dest = get_user_data_path(request, check_exists=False, param="dest")
# if not isinstance(source, str):
# return dest
#
# overwrite = request.query["overwrite"] != "false"
# if not overwrite and os.path.exists(dest):
# return web.Response(status=409)
#
# print(f"moving '{source}' -> '{dest}'")
# shutil.move(source, dest)
#
# resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
# return web.json_response(resp)
return web.Response(status=403, text="Forbidden")
8 changes: 7 additions & 1 deletion comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
if hasattr(ckpt, 'is_safetensors') and hasattr(ckpt, 'filename'):
is_safetensors = ckpt.is_safetensors
ckpt = ckpt.filename
else:
is_safetensors = ckpt.lower().endswith(".safetensors")

if is_safetensors:
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if safe_load:
Expand Down
Loading

0 comments on commit 0e97e90

Please sign in to comment.