From 3b2c605aa6a44af0f97e63a638d0beb0f7f6f5e3 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Thu, 29 Feb 2024 15:13:38 -0500 Subject: [PATCH] Next: Allow in place local installs of models --- invokeai/app/api/routers/model_manager.py | 2 ++ invokeai/app/services/model_install/model_install_default.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index 50ebe5ce647..47e0fd314e8 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -451,6 +451,7 @@ async def add_model_record( ) async def install_model( source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"), + inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False), # TODO(MM2): Can we type this? config: Optional[Dict[str, Any]] = Body( description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ", @@ -493,6 +494,7 @@ async def install_model( source=source, config=config, access_token=access_token, + inplace=bool(inplace), ) logger.info(f"Started installation of {source}") except UnknownModelException as e: diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index b3c6015f4df..b91f9610992 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -178,13 +178,14 @@ def heuristic_import( source: str, config: Optional[Dict[str, Any]] = None, access_token: Optional[str] = None, + inplace: bool = False, ) -> ModelInstallJob: variants = "|".join(ModelRepoVariant.__members__.values()) hf_repoid_re = f"^([^/:]+/[^/:]+)(?::({variants})?(?::/?([^:]+))?)?$" source_obj: Optional[StringLikeSource] = None if Path(source).exists(): # A local file or directory - source_obj = LocalModelSource(path=Path(source)) + source_obj = LocalModelSource(path=Path(source), inplace=inplace) elif match := re.match(hf_repoid_re, source): source_obj = HFModelSource( repo_id=match.group(1),