feat(nodes): better controlnet processors #6021

Merged
merged 11 commits on Mar 21, 2024
55 changes: 33 additions & 22 deletions invokeai/app/invocations/controlnet_image_processors.py
@@ -7,12 +7,8 @@
import cv2
import numpy as np
from controlnet_aux import (
CannyDetector,
ContentShuffleDetector,
HEDdetector,
LeresDetector,
LineartAnimeDetector,
LineartDetector,
MediapipeFaceDetector,
MidasDetector,
MLSDdetector,
@@ -39,8 +35,12 @@
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor

from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output

@@ -171,11 +171,12 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
title="Canny Processor",
tags=["controlnet", "canny"],
category="controlnet",
version="1.3.1",
version="1.3.2",
)
class CannyImageProcessorInvocation(ImageProcessorInvocation):
"""Canny edge detection for ControlNet"""

detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
low_threshold: int = InputField(
default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)"
@@ -188,12 +189,12 @@ def load_image(self, context: InvocationContext) -> Image.Image:
# Keep alpha channel for Canny processing to detect edges of transparent areas
return context.images.get_pil(self.image.image_name, "RGBA")

def run_processor(self, image):
canny_processor = CannyDetector()
processed_image = canny_processor(
def run_processor(self, image: Image.Image) -> Image.Image:
processed_image = get_canny_edges(
image,
self.low_threshold,
self.high_threshold,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
return processed_image
@@ -215,9 +216,9 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
# safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode)
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

def run_processor(self, image):
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = hed_processor(
def run_processor(self, image: Image.Image) -> Image.Image:
hed_processor = HEDProcessor()
processed_image = hed_processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
@@ -242,9 +243,9 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
coarse: bool = InputField(default=False, description="Whether to use coarse mode")

def run_processor(self, image):
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
processed_image = lineart_processor(
def run_processor(self, image: Image.Image) -> Image.Image:
lineart_processor = LineartProcessor()
processed_image = lineart_processor.run(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
)
return processed_image
@@ -263,9 +264,9 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

def run_processor(self, image):
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = processor(
def run_processor(self, image: Image.Image) -> Image.Image:
processor = LineartAnimeProcessor()
processed_image = processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
@@ -278,13 +279,14 @@ def run_processor(self, image):
title="Midas Depth Processor",
tags=["controlnet", "midas"],
category="controlnet",
version="1.2.2",
version="1.2.3",
)
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Midas depth processing to image"""

a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)")
bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`")
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
@@ -296,6 +298,7 @@ def run_processor(self, image):
a=np.pi * self.a_mult,
bg_th=self.bg_th,
image_resolution=self.image_resolution,
detect_resolution=self.detect_resolution,
# depth_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal,
)
@@ -420,19 +423,24 @@ def run_processor(self, image):
title="Mediapipe Face Processor",
tags=["controlnet", "mediapipe", "face"],
category="controlnet",
version="1.2.2",
version="1.2.3",
)
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
"""Applies mediapipe face processing to image"""

max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect")
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

def run_processor(self, image):
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(
image, max_faces=self.max_faces, min_confidence=self.min_confidence, image_resolution=self.image_resolution
image,
max_faces=self.max_faces,
min_confidence=self.min_confidence,
image_resolution=self.image_resolution,
detect_resolution=self.detect_resolution,
)
return processed_image

@@ -511,11 +519,12 @@ def run_processor(self, img):
title="Segment Anything Processor",
tags=["controlnet", "segmentanything"],
category="controlnet",
version="1.2.2",
version="1.2.3",
)
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
"""Applies segment anything processing to image"""

detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

def run_processor(self, image):
@@ -524,7 +533,9 @@ def run_processor(self, image):
"ybelkada/segment-anything", subfolder="checkpoints"
)
np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(np_img, image_resolution=self.image_resolution)
processed_image = segment_anything_processor(
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
)
return processed_image


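Taken together, the changes in this file follow one convention: each processor node exposes a detect_resolution and an image_resolution field and forwards both to its backend processor. A minimal sketch of that convention, assuming the real resize_image_to_resolution helper from invokeai.backend.image_util.util and a hypothetical detect_edges stand-in for any of the detectors:

def process(image, detect_resolution: int = 512, image_resolution: int = 512):
    # Fit the input to the detector's working resolution before detection...
    resized = resize_image_to_resolution(image, detect_resolution)
    detected = detect_edges(resized)  # hypothetical stand-in for the actual detector
    # ...then fit the detector's output to the requested output resolution.
    return resize_image_to_resolution(detected, image_resolution)
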
41 changes: 41 additions & 0 deletions invokeai/backend/image_util/canny.py
@@ -0,0 +1,41 @@
import cv2
from PIL import Image

from invokeai.backend.image_util.util import (
    cv2_to_pil,
    normalize_image_channel_count,
    pil_to_cv2,
    resize_image_to_resolution,
)


def get_canny_edges(
    image: Image.Image, low_threshold: int, high_threshold: int, detect_resolution: int, image_resolution: int
) -> Image.Image:
    """Returns the edges of an image using the Canny edge detection algorithm.

    Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license).

    Args:
        image: The input image.
        low_threshold: The lower threshold for the hysteresis procedure.
        high_threshold: The upper threshold for the hysteresis procedure.
        detect_resolution: The resolution of the input image. The image will be resized to this resolution before edge detection.
        image_resolution: The resolution of the output image. The edges will be resized to this resolution before returning.

    Returns:
        The Canny edges of the input image.
    """

    if image.mode != "RGB":
        image = image.convert("RGB")

    np_image = pil_to_cv2(image)
    np_image = normalize_image_channel_count(np_image)
    np_image = resize_image_to_resolution(np_image, detect_resolution)

    edge_map = cv2.Canny(np_image, low_threshold, high_threshold)
    edge_map = normalize_image_channel_count(edge_map)
    edge_map = resize_image_to_resolution(edge_map, image_resolution)

    return cv2_to_pil(edge_map)
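
Because the new backend function is pure, it can be exercised outside the node graph. A minimal usage sketch (file paths are hypothetical, not part of the PR):

from PIL import Image

from invokeai.backend.image_util.canny import get_canny_edges

image = Image.open("input.png")  # hypothetical test image
edges = get_canny_edges(image, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512)
edges.save("edges.png")

The detect and image resolutions default to 512 in the invocation; 100/200 is the conventional Canny threshold pair.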
142 changes: 142 additions & 0 deletions invokeai/backend/image_util/hed.py
@@ -0,0 +1,142 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""

import cv2
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image

from invokeai.backend.image_util.util import (
non_maximum_suppression,
normalize_image_channel_count,
np_to_pil,
pil_to_np,
resize_image_to_resolution,
safe_step,
)


class DoubleConvBlock(torch.nn.Module):
def __init__(self, input_channel, output_channel, layer_number):
super().__init__()
self.convs = torch.nn.Sequential()
self.convs.append(
torch.nn.Conv2d(
in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1
)
)
for _i in range(1, layer_number):
self.convs.append(
torch.nn.Conv2d(
in_channels=output_channel,
out_channels=output_channel,
kernel_size=(3, 3),
stride=(1, 1),
padding=1,
)
)
self.projection = torch.nn.Conv2d(
in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0
)

def __call__(self, x, down_sampling=False):
h = x
if down_sampling:
h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
for conv in self.convs:
h = conv(h)
h = torch.nn.functional.relu(h)
return h, self.projection(h)


class ControlNetHED_Apache2(torch.nn.Module):
def __init__(self):
super().__init__()
self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

def __call__(self, x):
h = x - self.norm
h, projection1 = self.block1(h)
h, projection2 = self.block2(h, down_sampling=True)
h, projection3 = self.block3(h, down_sampling=True)
h, projection4 = self.block4(h, down_sampling=True)
h, projection5 = self.block5(h, down_sampling=True)
return projection1, projection2, projection3, projection4, projection5


class HEDProcessor:
"""Holistically-Nested Edge Detection.

On instantiation, loads the HED model from the HuggingFace Hub.
"""

def __init__(self):
model_path = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
self.network = ControlNetHED_Apache2()
self.network.load_state_dict(torch.load(model_path, map_location="cpu"))
self.network.float().eval()

def to(self, device: torch.device):
self.network.to(device)
return self

def run(
self,
input_image: Image.Image,
detect_resolution: int = 512,
image_resolution: int = 512,
safe: bool = False,
scribble: bool = False,
) -> Image.Image:
"""Processes an image and returns the detected edges.

Args:
input_image: The input image.
detect_resolution: The resolution to fit the image to before edge detection.
image_resolution: The resolution to fit the edges to before returning.
safe: Whether to apply safe step to the detected edges.
scribble: Whether to apply non-maximum suppression and Gaussian blur to the detected edges.

Returns:
The detected edges.
"""
device = next(iter(self.network.parameters())).device
np_image = pil_to_np(input_image)
np_image = normalize_image_channel_count(np_image)
np_image = resize_image_to_resolution(np_image, detect_resolution)

assert np_image.ndim == 3
height, width, _channels = np_image.shape
with torch.no_grad():
image_hed = torch.from_numpy(np_image.copy()).float().to(device)
image_hed = rearrange(image_hed, "h w c -> 1 c h w")
edges = self.network(image_hed)
edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
edges = [cv2.resize(e, (width, height), interpolation=cv2.INTER_LINEAR) for e in edges]
edges = np.stack(edges, axis=2)
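            # Fuse the multi-scale side outputs: average across scales, then map to [0, 1] with a sigmoid.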
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            if safe:
                edge = safe_step(edge)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = edge
        detected_map = normalize_image_channel_count(detected_map)

        img = resize_image_to_resolution(np_image, image_resolution)
        height, width, _channels = img.shape

        detected_map = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR)

        if scribble:
            detected_map = non_maximum_suppression(detected_map, 127, 3.0)
            detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
            detected_map[detected_map > 4] = 255
            detected_map[detected_map < 255] = 0

        return np_to_pil(detected_map)
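
A minimal usage sketch of the new processor (file paths are hypothetical; the ControlNetHED.pth weights are fetched from the Hub when the class is instantiated):

import torch
from PIL import Image

from invokeai.backend.image_util.hed import HEDProcessor

processor = HEDProcessor()  # loads ControlNetHED.pth from lllyasviel/Annotators
if torch.cuda.is_available():
    processor = processor.to(torch.device("cuda"))

image = Image.open("input.png")  # hypothetical test image
edges = processor.run(image, detect_resolution=512, image_resolution=512, scribble=True)
edges.save("hed_scribble.png")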