Skip to content

Commit

Permalink
feat(mm): probe for main model default settings
Browse files Browse the repository at this point in the history
Currently, this is just the width and height, derived from the model base.
  • Loading branch information
psychedelicious committed Mar 12, 2024
1 parent 7d5e54d commit e618f44
Showing 1 changed file with 46 additions and 38 deletions.
84 changes: 46 additions & 38 deletions invokeai/backend/model_manager/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
BaseModelType,
ControlAdapterDefaultSettings,
InvalidModelConfigException,
MainModelDefaultSettings,
ModelConfigFactory,
ModelFormat,
ModelRepoVariant,
Expand Down Expand Up @@ -162,11 +163,13 @@ def probe(
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)

fields["default_settings"] = (
fields.get("default_settings") or probe.get_default_settings(fields["name"])
if isinstance(probe, ControlAdapterProbe)
else None
)
fields["default_settings"] = fields.get("default_settings")

if not fields["default_settings"]:
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}:
fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"])
elif fields["type"] is ModelType.Main:
fields["default_settings"] = get_default_settings_main(fields["base"])

if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
Expand Down Expand Up @@ -338,36 +341,41 @@ def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")


class ControlAdapterProbe(ProbeBase):
"""Adds `get_default_settings` for ControlNet and T2IAdapter probes"""

# TODO(psyche): It would be nice to get these from the invocations, but that creates circular dependencies.
# "canny": CannyImageProcessorInvocation.get_type()
MODEL_NAME_TO_PREPROCESSOR = {
"canny": "canny_image_processor",
"mlsd": "mlsd_image_processor",
"depth": "depth_anything_image_processor",
"bae": "normalbae_image_processor",
"normal": "normalbae_image_processor",
"sketch": "pidi_image_processor",
"scribble": "lineart_image_processor",
"lineart": "lineart_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"softedge": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
"pose": "dw_openpose_image_processor",
"mediapipe": "mediapipe_face_processor",
"pidi": "pidi_image_processor",
"zoe": "zoe_depth_image_processor",
"color": "color_map_image_processor",
}
# Probing utilities
MODEL_NAME_TO_PREPROCESSOR = {
"canny": "canny_image_processor",
"mlsd": "mlsd_image_processor",
"depth": "depth_anything_image_processor",
"bae": "normalbae_image_processor",
"normal": "normalbae_image_processor",
"sketch": "pidi_image_processor",
"scribble": "lineart_image_processor",
"lineart": "lineart_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"softedge": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
"pose": "dw_openpose_image_processor",
"mediapipe": "mediapipe_face_processor",
"pidi": "pidi_image_processor",
"zoe": "zoe_depth_image_processor",
"color": "color_map_image_processor",
}

@classmethod
def get_default_settings(cls, model_name: str) -> Optional[ControlAdapterDefaultSettings]:
for k, v in cls.MODEL_NAME_TO_PREPROCESSOR.items():
if k in model_name:
return ControlAdapterDefaultSettings(preprocessor=v)
return None

def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
if k in model_name:
return ControlAdapterDefaultSettings(preprocessor=v)
return None


def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
return MainModelDefaultSettings(width=512, height=512)
elif model_base is BaseModelType.StableDiffusionXL:
return MainModelDefaultSettings(width=1024, height=1024)
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
return None


# ##################################################3
Expand Down Expand Up @@ -493,7 +501,7 @@ def get_base_type(self) -> BaseModelType:
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")


class ControlNetCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
class ControlNetCheckpointProbe(CheckpointProbeBase):
"""Class for probing controlnets."""

def get_base_type(self) -> BaseModelType:
Expand Down Expand Up @@ -521,7 +529,7 @@ def get_base_type(self) -> BaseModelType:
raise NotImplementedError()


class T2IAdapterCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()

Expand Down Expand Up @@ -659,7 +667,7 @@ def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal


class ControlNetFolderProbe(FolderProbeBase, ControlAdapterProbe):
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():
Expand Down Expand Up @@ -733,7 +741,7 @@ def get_base_type(self) -> BaseModelType:
return BaseModelType.Any


class T2IAdapterFolderProbe(FolderProbeBase, ControlAdapterProbe):
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
if not config_file.exists():
Expand Down

0 comments on commit e618f44

Please sign in to comment.