diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config.py index 7b5015a670b..c407132e932 100644 --- a/invokeai/app/services/config.py +++ b/invokeai/app/services/config.py @@ -28,7 +28,6 @@ always_use_cpu: false free_gpu_mem: false Features: - restore: true esrgan: true patchmatch: true internet_available: true @@ -165,7 +164,7 @@ class InvokeBatch(InvokeAISettings): import os import sys from argparse import ArgumentParser -from omegaconf import OmegaConf, DictConfig +from omegaconf import OmegaConf, DictConfig, ListConfig from pathlib import Path from pydantic import BaseSettings, Field, parse_obj_as from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args @@ -189,7 +188,12 @@ def parse_args(self, argv: list = sys.argv[1:]): opt = parser.parse_args(argv) for name in self.__fields__: if name not in self._excluded(): - setattr(self, name, getattr(opt, name)) + value = getattr(opt, name) + if isinstance(value, ListConfig): + value = list(value) + elif isinstance(value, DictConfig): + value = dict(value) + setattr(self, name, value) def to_yaml(self) -> str: """ @@ -282,14 +286,10 @@ def _excluded_from_yaml(self) -> List[str]: return [ "type", "initconf", - "gpu_mem_reserved", - "max_loaded_models", "version", "from_file", "model", - "restore", "root", - "nsfw_checker", ] class Config: @@ -388,15 +388,11 @@ class InvokeAIAppConfig(InvokeAISettings): internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features') log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features') patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features') - restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED') always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') - max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='DEPRECATED') max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance') max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance') - gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED') - nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED') precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') @@ -414,9 +410,7 @@ class InvokeAIAppConfig(InvokeAISettings): outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths') from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths') use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths') - ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert') - - model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models') + ignore_missing_core_models : bool = Field(default=False, description='Ignore missing models in models/core/convert', category='Features') log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=", "syslog=path|address:host:port", "http="', category="Logging") # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues @@ -426,6 +420,9 @@ class InvokeAIAppConfig(InvokeAISettings): version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other") # fmt: on + class Config: + validate_assignment = True + def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False): """ Update settings with contents of init file, environment, and diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index 4bf2a484a19..714b688996e 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -44,6 +44,8 @@ ) from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.model_install import addModelsForm, process_and_execute + +# TO DO - Move all the frontend code into invokeai.frontend.install from invokeai.frontend.install.widgets import ( SingleSelectColumns, CenteredButtonPress, @@ -61,6 +63,7 @@ ModelInstall, ) from invokeai.backend.model_management.model_probe import ModelType, BaseModelType +from pydantic.error_wrappers import ValidationError warnings.filterwarnings("ignore") transformers.logging.set_verbosity_error() @@ -654,10 +657,13 @@ def migrate_init_file(legacy_format: Path): old = legacy_parser.parse_args([f"@{str(legacy_format)}"]) new = InvokeAIAppConfig.get_config() - fields = list(get_type_hints(InvokeAIAppConfig).keys()) + fields = [x for x, y in InvokeAIAppConfig.__fields__.items() if y.field_info.extra.get("category") != "DEPRECATED"] for attr in fields: if hasattr(old, attr): - setattr(new, attr, getattr(old, attr)) + try: + setattr(new, attr, getattr(old, attr)) + except ValidationError as e: + print(f"* Ignoring incompatible value for field {attr}:\n {str(e)}") # a few places where the field names have changed and we have to # manually add in the new names/values