From 4563b6a783beb61ec55f3d43d1cb64a5dfba35e7 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Tue, 23 Apr 2024 22:38:02 +1000
Subject: [PATCH 1/2] fix(nodes): fix nsfw checker model download

---
 invokeai/app/api/routers/app_info.py          |  5 +-
 invokeai/backend/image_util/safety_checker.py | 48 ++++++++++++++-----
 2 files changed, 36 insertions(+), 17 deletions(-)

diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py
index 21286ac2b03..c3bc98a0387 100644
--- a/invokeai/app/api/routers/app_info.py
+++ b/invokeai/app/api/routers/app_info.py
@@ -13,7 +13,6 @@
 from invokeai.app.invocations.upscale import ESRGAN_MODELS
 from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
 from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
-from invokeai.backend.image_util.safety_checker import SafetyChecker
 from invokeai.backend.util.logging import logging
 from invokeai.version import __version__
 
@@ -109,9 +108,7 @@ async def get_config() -> AppConfig:
         upscaling_models.append(str(Path(model).stem))
     upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
 
-    nsfw_methods = []
-    if SafetyChecker.safety_checker_available():
-        nsfw_methods.append("nsfw_checker")
+    nsfw_methods = ["nsfw_checker"]
 
     watermarking_methods = ["invisible_watermark"]
 
diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py
index 60dcd93fcc5..4e0bfe56e56 100644
--- a/invokeai/backend/image_util/safety_checker.py
+++ b/invokeai/backend/image_util/safety_checker.py
@@ -8,7 +8,7 @@
 import numpy as np
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from PIL import Image
+from PIL import Image, ImageFilter
 from transformers import AutoFeatureExtractor
 
 import invokeai.backend.util.logging as logger
 
@@ -16,6 +16,7 @@
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
+repo_id = "CompVis/stable-diffusion-safety-checker"
 CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
 
 
@@ -24,30 +25,30 @@ class SafetyChecker:
     Wrapper around SafetyChecker model.
""" - safety_checker = None feature_extractor = None - tried_load: bool = False + safety_checker = None @classmethod def _load_safety_checker(cls): - if cls.tried_load: + if cls.safety_checker is not None and cls.feature_extractor is not None: return try: - cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(get_config().models_path / CHECKER_PATH) - cls.feature_extractor = AutoFeatureExtractor.from_pretrained(get_config().models_path / CHECKER_PATH) + model_path = get_config().models_path / CHECKER_PATH + if model_path.exists(): + cls.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path) + cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(model_path) + else: + model_path.mkdir(parents=True, exist_ok=True) + cls.feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id) + cls.feature_extractor.save_pretrained(model_path, safe_serialization=True) + cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo_id) + cls.safety_checker.save_pretrained(model_path, safe_serialization=True) except Exception as e: logger.warning(f"Could not load NSFW checker: {str(e)}") - cls.tried_load = True - - @classmethod - def safety_checker_available(cls) -> bool: - return Path(get_config().models_path, CHECKER_PATH).exists() @classmethod def has_nsfw_concept(cls, image: Image.Image) -> bool: - if not cls.safety_checker_available() and cls.tried_load: - return False cls._load_safety_checker() if cls.safety_checker is None or cls.feature_extractor is None: return False @@ -60,3 +61,24 @@ def has_nsfw_concept(cls, image: Image.Image) -> bool: with SilenceWarnings(): checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values) return has_nsfw_concept[0] + + @classmethod + def blur_if_nsfw(cls, image: Image.Image) -> Image.Image: + if cls.has_nsfw_concept(image): + logger.info("A potentially NSFW image has been detected. 
Image will be blurred.") + blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32)) + caution = cls._get_caution_img() + # Center the caution image on the blurred image + x = (blurry_image.width - caution.width) // 2 + y = (blurry_image.height - caution.height) // 2 + blurry_image.paste(caution, (x, y), caution) + image = blurry_image + + return image + + @classmethod + def _get_caution_img(cls) -> Image.Image: + import invokeai.app.assets.images as image_assets + + caution = Image.open(Path(image_assets.__path__[0]) / "caution.png") + return caution.resize((caution.width // 2, caution.height // 2)) From 18179147154debbcf5fc86e777b72dc8a3946e1e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 13 May 2024 18:33:25 +1000 Subject: [PATCH 2/2] feat(nodes): use new `blur_if_nsfw` method --- invokeai/app/invocations/image.py | 16 ++-------------- invokeai/backend/image_util/safety_checker.py | 2 +- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 05ffc0d67b0..65e7ce5e067 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -1,6 +1,5 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from pathlib import Path from typing import Literal, Optional import cv2 @@ -504,7 +503,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: title="Blur NSFW Image", tags=["image", "nsfw"], category="image", - version="1.2.2", + version="1.2.3", ) class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard): """Add blur to NSFW-flagged images""" @@ -516,23 +515,12 @@ def invoke(self, context: InvocationContext) -> ImageOutput: logger = context.logger logger.debug("Running NSFW checker") - if SafetyChecker.has_nsfw_concept(image): - logger.info("A potentially NSFW image has been detected. Image will be blurred.") - blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32)) - caution = self._get_caution_img() - blurry_image.paste(caution, (0, 0), caution) - image = blurry_image + image = SafetyChecker.blur_if_nsfw(image) image_dto = context.images.save(image=image) return ImageOutput.build(image_dto) - def _get_caution_img(self) -> Image.Image: - import invokeai.app.assets.images as image_assets - - caution = Image.open(Path(image_assets.__path__[0]) / "caution.png") - return caution.resize((caution.width // 2, caution.height // 2)) - @invocation( "img_watermark", diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index 4e0bfe56e56..ab09a296197 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -65,7 +65,7 @@ def has_nsfw_concept(cls, image: Image.Image) -> bool: @classmethod def blur_if_nsfw(cls, image: Image.Image) -> Image.Image: if cls.has_nsfw_concept(image): - logger.info("A potentially NSFW image has been detected. Image will be blurred.") + logger.warning("A potentially NSFW image has been detected. Image will be blurred.") blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32)) caution = cls._get_caution_img() # Center the caution image on the blurred image
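
Usage sketch (illustrative; not part of either patch): after these two changes, callers go through the single `SafetyChecker.blur_if_nsfw` entry point. The first call lazily loads the checker, downloading it from the CompVis/stable-diffusion-safety-checker repo into the models directory when it is not already present. The file names below are hypothetical.

    # Minimal usage sketch, assuming an installed InvokeAI environment.
    # "input.png" and "output.png" are hypothetical paths, not from the patches.
    from PIL import Image

    from invokeai.backend.image_util.safety_checker import SafetyChecker

    image = Image.open("input.png")

    # Returns the image unchanged when no NSFW concept is detected; otherwise
    # returns a Gaussian-blurred copy with the caution image pasted at the center.
    checked = SafetyChecker.blur_if_nsfw(image)
    checked.save("output.png")

Centralizing the blur-and-caution logic in `blur_if_nsfw` is what lets PATCH 2/2 delete the duplicated blur code from `ImageNSFWBlurInvocation`.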