diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index a9827fb5643..edcb3acceb4 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -16,6 +16,7 @@ BaseModelType, ControlAdapterDefaultSettings, InvalidModelConfigException, + MainModelDefaultSettings, ModelConfigFactory, ModelFormat, ModelRepoVariant, @@ -162,11 +163,13 @@ def probe( fields["format"] = fields.get("format") or probe.get_format() fields["hash"] = fields.get("hash") or ModelHash().hash(model_path) - fields["default_settings"] = ( - fields.get("default_settings") or probe.get_default_settings(fields["name"]) - if isinstance(probe, ControlAdapterProbe) - else None - ) + fields["default_settings"] = fields.get("default_settings") + + if not fields["default_settings"]: + if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}: + fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"]) + elif fields["type"] is ModelType.Main: + fields["default_settings"] = get_default_settings_main(fields["base"]) if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase): fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() @@ -338,36 +341,41 @@ def _scan_model(cls, model_name: str, checkpoint: Path) -> None: raise Exception("The model {model_name} is potentially infected by malware. Aborting import.") -class ControlAdapterProbe(ProbeBase): - """Adds `get_default_settings` for ControlNet and T2IAdapter probes""" - - # TODO(psyche): It would be nice to get these from the invocations, but that creates circular dependencies. - # "canny": CannyImageProcessorInvocation.get_type() - MODEL_NAME_TO_PREPROCESSOR = { - "canny": "canny_image_processor", - "mlsd": "mlsd_image_processor", - "depth": "depth_anything_image_processor", - "bae": "normalbae_image_processor", - "normal": "normalbae_image_processor", - "sketch": "pidi_image_processor", - "scribble": "lineart_image_processor", - "lineart": "lineart_image_processor", - "lineart_anime": "lineart_anime_image_processor", - "softedge": "hed_image_processor", - "shuffle": "content_shuffle_image_processor", - "pose": "dw_openpose_image_processor", - "mediapipe": "mediapipe_face_processor", - "pidi": "pidi_image_processor", - "zoe": "zoe_depth_image_processor", - "color": "color_map_image_processor", - } +# Probing utilities +MODEL_NAME_TO_PREPROCESSOR = { + "canny": "canny_image_processor", + "mlsd": "mlsd_image_processor", + "depth": "depth_anything_image_processor", + "bae": "normalbae_image_processor", + "normal": "normalbae_image_processor", + "sketch": "pidi_image_processor", + "scribble": "lineart_image_processor", + "lineart": "lineart_image_processor", + "lineart_anime": "lineart_anime_image_processor", + "softedge": "hed_image_processor", + "shuffle": "content_shuffle_image_processor", + "pose": "dw_openpose_image_processor", + "mediapipe": "mediapipe_face_processor", + "pidi": "pidi_image_processor", + "zoe": "zoe_depth_image_processor", + "color": "color_map_image_processor", +} - @classmethod - def get_default_settings(cls, model_name: str) -> Optional[ControlAdapterDefaultSettings]: - for k, v in cls.MODEL_NAME_TO_PREPROCESSOR.items(): - if k in model_name: - return ControlAdapterDefaultSettings(preprocessor=v) - return None + +def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]: + for k, v in MODEL_NAME_TO_PREPROCESSOR.items(): + if k in model_name: + return ControlAdapterDefaultSettings(preprocessor=v) + return None + + +def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]: + if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2: + return MainModelDefaultSettings(width=512, height=512) + elif model_base is BaseModelType.StableDiffusionXL: + return MainModelDefaultSettings(width=1024, height=1024) + # We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models. + return None # ##################################################3 @@ -493,7 +501,7 @@ def get_base_type(self) -> BaseModelType: raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type") -class ControlNetCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe): +class ControlNetCheckpointProbe(CheckpointProbeBase): """Class for probing controlnets.""" def get_base_type(self) -> BaseModelType: @@ -521,7 +529,7 @@ def get_base_type(self) -> BaseModelType: raise NotImplementedError() -class T2IAdapterCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe): +class T2IAdapterCheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: raise NotImplementedError() @@ -659,7 +667,7 @@ def get_variant_type(self) -> ModelVariantType: return ModelVariantType.Normal -class ControlNetFolderProbe(FolderProbeBase, ControlAdapterProbe): +class ControlNetFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: config_file = self.model_path / "config.json" if not config_file.exists(): @@ -733,7 +741,7 @@ def get_base_type(self) -> BaseModelType: return BaseModelType.Any -class T2IAdapterFolderProbe(FolderProbeBase, ControlAdapterProbe): +class T2IAdapterFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: config_file = self.model_path / "config.json" if not config_file.exists():