Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generic progress events #6908

Merged
merged 10 commits into from
Sep 22, 2024
4 changes: 2 additions & 2 deletions invokeai/app/api/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
DownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationProgressEvent,
InvocationStartedEvent,
ModelEventBase,
ModelInstallCancelledEvent,
Expand Down Expand Up @@ -55,7 +55,7 @@ class BulkDownloadSubscriptionEvent(BaseModel):

QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
QueueItemStatusChangedEvent,
Expand Down
59 changes: 52 additions & 7 deletions invokeai/app/invocations/spandrel_image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from typing import Callable

import numpy as np
Expand Down Expand Up @@ -61,6 +62,7 @@ def upscale_image(
tile_size: int,
spandrel_model: SpandrelImageToImageModel,
is_canceled: Callable[[], bool],
step_callback: Callable[[int, int], None],
) -> Image.Image:
# Compute the image tiles.
if tile_size > 0:
Expand Down Expand Up @@ -103,7 +105,12 @@ def upscale_image(
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)

# Run the model on each tile.
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")

# Update progress, starting with 0.
step_callback(0, pbar.total)

for tile, scaled_tile in pbar:
# Exit early if the invocation has been canceled.
if is_canceled():
raise CanceledException
Expand Down Expand Up @@ -136,6 +143,8 @@ def upscale_image(
:,
] = output_tile[top_overlap:, left_overlap:, :]

step_callback(pbar.n + 1, pbar.total)

# Convert the output tensor to a PIL image.
np_image = output_tensor.detach().numpy().astype(np.uint8)
pil_image = Image.fromarray(np_image)
Expand All @@ -151,12 +160,20 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)

def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
message=f"Processing tile {step}/{total_steps}",
percentage=step / total_steps,
)

# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)

# Upscale the image
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
pil_image = self.upscale_image(
image, self.tile_size, spandrel_model, context.util.is_canceled, step_callback
)

image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
Expand Down Expand Up @@ -197,12 +214,27 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
target_width = int(image.width * self.scale)
target_height = int(image.height * self.scale)

def step_callback(iteration: int, step: int, total_steps: int) -> None:
context.util.signal_progress(
message=self._get_progress_message(iteration, step, total_steps),
percentage=step / total_steps,
)

# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)

iteration = 1
context.util.signal_progress(self._get_progress_message(iteration))

# First pass of upscaling. Note: `pil_image` will be mutated.
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
pil_image = self.upscale_image(
image,
self.tile_size,
spandrel_model,
context.util.is_canceled,
functools.partial(step_callback, iteration),
)

# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
Expand All @@ -211,16 +243,22 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

if is_upscale_model:
# This is an upscale model, so we should keep upscaling until we reach the target size.
iterations = 1
while pil_image.width < target_width or pil_image.height < target_height:
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
iterations += 1
iteration += 1
context.util.signal_progress(self._get_progress_message(iteration))
pil_image = self.upscale_image(
pil_image,
self.tile_size,
spandrel_model,
context.util.is_canceled,
functools.partial(step_callback, iteration),
)

# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
# we should never reach this limit.
if iterations >= 5:
if iteration >= 5:
context.logger.warning(
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
)
Expand Down Expand Up @@ -251,3 +289,10 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)

@classmethod
def _get_progress_message(cls, iteration: int, step: int | None = None, total_steps: int | None = None) -> str:
if step is not None and total_steps is not None:
return f"Processing iteration {iteration}, tile {step}/{total_steps}"

return f"Processing iteration {iteration}"
14 changes: 7 additions & 7 deletions invokeai/app/services/events/events_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
DownloadStartedEvent,
EventBase,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationProgressEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
Expand All @@ -30,7 +30,6 @@
QueueClearedEvent,
QueueItemStatusChangedEvent,
)
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState

if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
Expand Down Expand Up @@ -58,15 +57,16 @@ def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "B
"""Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation))

def emit_invocation_denoise_progress(
def emit_invocation_progress(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
intermediate_state: PipelineIntermediateState,
progress_image: "ProgressImage",
message: str,
percentage: float | None = None,
image: "ProgressImage | None" = None,
) -> None:
"""Emitted at each step during denoising of an invocation."""
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
"""Emitted at periodically during an invocation"""
self.dispatch(InvocationProgressEvent.build(queue_item, invocation, message, percentage, image))

def emit_invocation_complete(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
Expand Down
48 changes: 17 additions & 31 deletions invokeai/app/services/events/events_common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from math import floor
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar

from fastapi_events.handlers.local import local_handler
Expand All @@ -16,7 +15,6 @@
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState

if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
Expand Down Expand Up @@ -123,28 +121,28 @@ def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "Invo


@payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
class InvocationProgressEvent(InvocationEventBase):
"""Event model for invocation_progress"""

__event_name__ = "invocation_denoise_progress"
__event_name__ = "invocation_progress"

progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
step: int = Field(description="The current step of the invocation")
total_steps: int = Field(description="The total number of steps in the invocation")
order: int = Field(description="The order of the invocation in the session")
percentage: float = Field(description="The percentage of completion of the invocation")
message: str = Field(description="A message to display")
percentage: float | None = Field(
default=None, ge=0, le=1, description="The percentage of the progress (omit to indicate indeterminate progress)"
)
image: ProgressImage | None = Field(
default=None, description="An image representing the current state of the progress"
)

@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
message: str,
percentage: float | None = None,
image: ProgressImage | None = None,
) -> "InvocationProgressEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
Expand All @@ -154,23 +152,11 @@ def build(
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
progress_image=progress_image,
step=step,
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
percentage=percentage,
image=image,
message=message,
)

@staticmethod
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
"""Calculate the percentage of completion of denoising."""
if total_steps == 0:
return 0.0
if scheduler_order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)


@payload_schema.register
class InvocationCompleteEvent(InvocationEventBase):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from PIL.Image import Image as PILImageType
from pydantic import BaseModel, Field

from invokeai.backend.util.util import image_to_dataURL


class SessionProcessorStatus(BaseModel):
is_started: bool = Field(description="Whether the session processor is started")
Expand All @@ -15,6 +18,16 @@ class CanceledException(Exception):
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""

width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
width: int = Field(ge=1, description="The effective width of the image in pixels")
height: int = Field(ge=1, description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

@classmethod
def build(cls, image: PILImageType, size: tuple[int, int] | None = None) -> "ProgressImage":
"""Build a ProgressImage from a PIL image"""

return cls(
width=size[0] if size else image.width,
height=size[1] if size else image.height,
dataURL=image_to_dataURL(image, image_format="JPEG"),
)
69 changes: 65 additions & 4 deletions invokeai/app/services/shared/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.util.step_callback import flux_step_callback, stable_diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModel,
Expand Down Expand Up @@ -550,10 +551,9 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m
"""

stable_diffusion_step_callback(
context_data=self._data,
signal_progress=self.signal_progress,
intermediate_state=intermediate_state,
base_model=base_model,
events=self._services.events,
is_canceled=self.is_canceled,
)

Expand All @@ -569,12 +569,73 @@ def flux_step_callback(self, intermediate_state: PipelineIntermediateState) -> N
"""

flux_step_callback(
context_data=self._data,
signal_progress=self.signal_progress,
intermediate_state=intermediate_state,
events=self._services.events,
is_canceled=self.is_canceled,
)

def signal_progress(
self,
message: str,
percentage: float | None = None,
image: Image | None = None,
image_size: tuple[int, int] | None = None,
) -> None:
"""Signals the progress of some long-running invocation. The progress is displayed in the UI.

If a percentage is provided, the UI will display a progress bar and automatically append the percentage to the
message. You should not include the percentage in the message.

Example:
```py
total_steps = 10
for i in range(total_steps):
percentage = i / (total_steps - 1)
context.util.signal_progress("Doing something cool", percentage)
```

If an image is provided, the UI will display it. If your image should be displayed at a different size, provide
a tuple of `(width, height)` for the `image_size` parameter. The image will be displayed at the specified size
in the UI.

For example, SD denoising progress images are 1/8 the size of the original image, so you'd do this to ensure the
image is displayed at the correct size:
```py
# Calculate the output size of the image (8x the progress image's size)
width = progress_image.width * 8
height = progress_image.height * 8
# Signal the progress with the image and output size
signal_progress("Denoising", percentage, progress_image, (width, height))
```

If your progress image is very large, consider downscaling it to reduce the payload size and provide the original
size to the `image_size` parameter. The PIL `thumbnail` method is useful for this, as it maintains the aspect
ratio of the image:
```py
# `thumbnail` modifies the image in-place, so we need to first make a copy
thumbnail_image = progress_image.copy()
# Resize the image to a maximum of 256x256 pixels, maintaining the aspect ratio
thumbnail_image.thumbnail((256, 256))
# Signal the progress with the thumbnail, passing the original size
signal_progress("Denoising", percentage, thumbnail, progress_image.size)
```

Args:
message: A message describing the current status. Do not include the percentage in this message.
percentage: The current percentage completion for the process. Omit for indeterminate progress.
image: An optional image to display.
image_size: The optional size of the image to display. If omitted, the image will be displayed at its
original size.
"""

self._services.events.emit_invocation_progress(
queue_item=self._data.queue_item,
invocation=self._data.invocation,
message=message,
percentage=percentage,
image=ProgressImage.build(image, image_size) if image else None,
)


class InvocationContext:
"""Provides access to various services and data for the current invocation.
Expand Down
Loading
Loading