Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: generic progress events #6715

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions invokeai/app/api/sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
DownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationProgressEvent,
InvocationStartedEvent,
ModelEventBase,
ModelInstallCancelledEvent,
Expand Down Expand Up @@ -55,7 +55,7 @@ class BulkDownloadSubscriptionEvent(BaseModel):

QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
QueueItemStatusChangedEvent,
Expand Down
59 changes: 52 additions & 7 deletions invokeai/app/invocations/spandrel_image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
from typing import Callable

import numpy as np
Expand Down Expand Up @@ -61,6 +62,7 @@ def upscale_image(
tile_size: int,
spandrel_model: SpandrelImageToImageModel,
is_canceled: Callable[[], bool],
step_callback: Callable[[int, int], None],
) -> Image.Image:
# Compute the image tiles.
if tile_size > 0:
Expand Down Expand Up @@ -103,7 +105,12 @@ def upscale_image(
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)

# Run the model on each tile.
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")

# Update progress, starting with 0.
step_callback(0, pbar.total)

for tile, scaled_tile in pbar:
# Exit early if the invocation has been canceled.
if is_canceled():
raise CanceledException
Expand Down Expand Up @@ -136,6 +143,8 @@ def upscale_image(
:,
] = output_tile[top_overlap:, left_overlap:, :]

step_callback(pbar.n + 1, pbar.total)

# Convert the output tensor to a PIL image.
np_image = output_tensor.detach().numpy().astype(np.uint8)
pil_image = Image.fromarray(np_image)
Expand All @@ -151,12 +160,20 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)

def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
message=f"Processing image (tile {step}/{total_steps})",
percentage=step / total_steps,
)

# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)

# Upscale the image
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
pil_image = self.upscale_image(
image, self.tile_size, spandrel_model, context.util.is_canceled, step_callback
)

image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
Expand Down Expand Up @@ -197,12 +214,27 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
target_width = int(image.width * self.scale)
target_height = int(image.height * self.scale)

def step_callback(iteration: int, step: int, total_steps: int) -> None:
context.util.signal_progress(
message=self._get_progress_message(iteration, step, total_steps),
percentage=step / total_steps,
)

# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)

iteration = 1
context.util.signal_progress(self._get_progress_message(iteration))

# First pass of upscaling. Note: `pil_image` will be mutated.
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
pil_image = self.upscale_image(
image,
self.tile_size,
spandrel_model,
context.util.is_canceled,
functools.partial(step_callback, iteration),
)

# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
Expand All @@ -211,16 +243,22 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

if is_upscale_model:
# This is an upscale model, so we should keep upscaling until we reach the target size.
iterations = 1
while pil_image.width < target_width or pil_image.height < target_height:
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
iterations += 1
iteration += 1
context.util.signal_progress(self._get_progress_message(iteration))
pil_image = self.upscale_image(
pil_image,
self.tile_size,
spandrel_model,
context.util.is_canceled,
functools.partial(step_callback, iteration),
)

# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
# we should never reach this limit.
if iterations >= 5:
if iteration >= 5:
context.logger.warning(
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
)
Expand Down Expand Up @@ -251,3 +289,10 @@ def invoke(self, context: InvocationContext) -> ImageOutput:

image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)

@classmethod
def _get_progress_message(cls, iteration: int, step: int | None = None, total_steps: int | None = None) -> str:
if step is not None and total_steps is not None:
return f"Processing image (iteration {iteration}, tile {step}/{total_steps})"

return f"Processing image (iteration {iteration})"
16 changes: 8 additions & 8 deletions invokeai/app/services/events/events_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
DownloadStartedEvent,
EventBase,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationProgressEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
Expand All @@ -30,13 +30,12 @@
QueueClearedEvent,
QueueItemStatusChangedEvent,
)
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.app.services.session_processor.session_processor_common import ProgressImage

if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
Expand All @@ -58,15 +57,16 @@ def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "B
"""Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation))

def emit_invocation_denoise_progress(
def emit_invocation_progress(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
intermediate_state: PipelineIntermediateState,
progress_image: "ProgressImage",
message: str,
percentage: float | None = None,
image: ProgressImage | None = None,
) -> None:
"""Emitted at each step during denoising of an invocation."""
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
"""Emitted at each step during an invocation"""
self.dispatch(InvocationProgressEvent.build(queue_item, invocation, message, percentage, image))

def emit_invocation_complete(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
Expand Down
48 changes: 17 additions & 31 deletions invokeai/app/services/events/events_common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from math import floor
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar

from fastapi_events.handlers.local import local_handler
Expand All @@ -16,7 +15,6 @@
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState

if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
Expand Down Expand Up @@ -121,52 +119,40 @@ def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "Invo


@payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
class InvocationProgressEvent(InvocationEventBase):
"""Event model for invocation_progress"""

__event_name__ = "invocation_denoise_progress"
__event_name__ = "invocation_progress"

progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
step: int = Field(description="The current step of the invocation")
total_steps: int = Field(description="The total number of steps in the invocation")
order: int = Field(description="The order of the invocation in the session")
percentage: float = Field(description="The percentage of completion of the invocation")
message: str = Field(description="A message to display")
percentage: float | None = Field(
default=None, ge=0, le=1, description="The percentage of the progress (omit to indicate indeterminate progress)"
)
image: ProgressImage | None = Field(
default=None, description="An image representing the current state of the progress"
)

@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
message: str,
percentage: float | None = None,
image: ProgressImage | None = None,
) -> "InvocationProgressEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
progress_image=progress_image,
step=step,
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
percentage=percentage,
image=image,
message=message,
)

@staticmethod
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
"""Calculate the percentage of completion of denoising."""
if total_steps == 0:
return 0.0
if scheduler_order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)


@payload_schema.register
class InvocationCompleteEvent(InvocationEventBase):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from PIL.Image import Image as PILImageType
from pydantic import BaseModel, Field

from invokeai.backend.util.util import image_to_dataURL


class SessionProcessorStatus(BaseModel):
is_started: bool = Field(description="Whether the session processor is started")
Expand All @@ -15,6 +18,16 @@ class CanceledException(Exception):
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""

width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
width: int = Field(ge=1, description="The effective width of the image in pixels")
height: int = Field(ge=1, description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

@classmethod
def build(cls, image: PILImageType, size: tuple[int, int] | None = None) -> "ProgressImage":
"""Build a ProgressImage from a PIL image"""

return cls(
width=size[0] if size else image.width,
height=size[1] if size else image.height,
dataURL=image_to_dataURL(image, image_format="JPEG"),
)
56 changes: 54 additions & 2 deletions invokeai/app/services/shared/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModel,
Expand Down Expand Up @@ -550,13 +551,64 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m
"""

stable_diffusion_step_callback(
context_data=self._data,
signal_progress=self.signal_progress,
intermediate_state=intermediate_state,
base_model=base_model,
events=self._services.events,
is_canceled=self.is_canceled,
)

def signal_progress(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should called update_status instead?

self, message: str, percentage: float | None = None, image: ProgressImage | None = None
) -> None:
"""Signals the progress of some long-running invocation. The progress is displayed in the UI.

If you have an image to display, use `ProgressImage.build` to create the object.

If your progress image should be displayed at a different size, provide a tuple of `(width, height)` when
building the progress image.

For example, SD denoising progress images are 1/8 the size of the original image. In this case, the progress
image should be built like this to ensure it displays at the correct size:
```py
progress_image = ProgressImage.build(image, (width * 8, height * 8))
```

If your progress image is very large, consider downscaling it to reduce the payload size.

Example:
```py
total_steps = 10
for i in range(total_steps):
# Do some iterative progressing
image = do_iterative_processing(image)

# Calculate the percentage
step = i + 1
percentage = step / total_steps

# Create a short, friendly message
message = f"Processing (step {step}/{total_steps})"

# Build the progress image
progress_image = ProgressImage.build(image)

# Send progress to the UI
context.util.signal_progress(message, percentage, progress_image)
```

Args:
message: A message describing the current status.
percentage: The current percentage completion for the process. Omit for indeterminate progress.
image: An optional progress image to display.
"""
self._services.events.emit_invocation_progress(
queue_item=self._data.queue_item,
invocation=self._data.invocation,
message=message,
percentage=percentage,
image=image,
)


class InvocationContext:
"""Provides access to various services and data for the current invocation.
Expand Down
Loading
Loading