Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(backend): fix nsfw checker catch-22 #6360

Merged
merged 2 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions invokeai/app/api/routers/app_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.util.logging import logging
from invokeai.version import __version__

Expand Down Expand Up @@ -109,9 +108,7 @@ async def get_config() -> AppConfig:
upscaling_models.append(str(Path(model).stem))
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)

nsfw_methods = []
if SafetyChecker.safety_checker_available():
nsfw_methods.append("nsfw_checker")
nsfw_methods = ["nsfw_checker"]

watermarking_methods = ["invisible_watermark"]

Expand Down
16 changes: 2 additions & 14 deletions invokeai/app/invocations/image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from pathlib import Path
from typing import Literal, Optional

import cv2
Expand Down Expand Up @@ -504,7 +503,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
title="Blur NSFW Image",
tags=["image", "nsfw"],
category="image",
version="1.2.2",
version="1.2.3",
)
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add blur to NSFW-flagged images"""
Expand All @@ -516,23 +515,12 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

logger = context.logger
logger.debug("Running NSFW checker")
if SafetyChecker.has_nsfw_concept(image):
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = self._get_caution_img()
blurry_image.paste(caution, (0, 0), caution)
image = blurry_image
image = SafetyChecker.blur_if_nsfw(image)

image_dto = context.images.save(image=image)

return ImageOutput.build(image_dto)

def _get_caution_img(self) -> Image.Image:
import invokeai.app.assets.images as image_assets

caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))


@invocation(
"img_watermark",
Expand Down
48 changes: 35 additions & 13 deletions invokeai/backend/image_util/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@

import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor

import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings

repo_id = "CompVis/stable-diffusion-safety-checker"
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"


Expand All @@ -24,30 +25,30 @@ class SafetyChecker:
Wrapper around SafetyChecker model.
"""

safety_checker = None
feature_extractor = None
tried_load: bool = False
safety_checker = None

@classmethod
def _load_safety_checker(cls):
if cls.tried_load:
if cls.safety_checker is not None and cls.feature_extractor is not None:
return

try:
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(get_config().models_path / CHECKER_PATH)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(get_config().models_path / CHECKER_PATH)
model_path = get_config().models_path / CHECKER_PATH
if model_path.exists():
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(model_path)
else:
model_path.mkdir(parents=True, exist_ok=True)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
cls.feature_extractor.save_pretrained(model_path, safe_serialization=True)
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo_id)
cls.safety_checker.save_pretrained(model_path, safe_serialization=True)
except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}")
cls.tried_load = True

@classmethod
def safety_checker_available(cls) -> bool:
return Path(get_config().models_path, CHECKER_PATH).exists()

@classmethod
def has_nsfw_concept(cls, image: Image.Image) -> bool:
if not cls.safety_checker_available() and cls.tried_load:
return False
cls._load_safety_checker()
if cls.safety_checker is None or cls.feature_extractor is None:
return False
Expand All @@ -60,3 +61,24 @@ def has_nsfw_concept(cls, image: Image.Image) -> bool:
with SilenceWarnings():
checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
return has_nsfw_concept[0]

@classmethod
def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
if cls.has_nsfw_concept(image):
logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = cls._get_caution_img()
# Center the caution image on the blurred image
x = (blurry_image.width - caution.width) // 2
y = (blurry_image.height - caution.height) // 2
blurry_image.paste(caution, (x, y), caution)
image = blurry_image

return image

@classmethod
def _get_caution_img(cls) -> Image.Image:
import invokeai.app.assets.images as image_assets

caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))
Loading