diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py
index b39922c69bf..188f958c887 100644
--- a/invokeai/app/api/sockets.py
+++ b/invokeai/app/api/sockets.py
@@ -20,8 +20,8 @@
DownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
- InvocationDenoiseProgressEvent,
InvocationErrorEvent,
+ InvocationProgressEvent,
InvocationStartedEvent,
ModelEventBase,
ModelInstallCancelledEvent,
@@ -55,7 +55,7 @@ class BulkDownloadSubscriptionEvent(BaseModel):
QUEUE_EVENTS = {
InvocationStartedEvent,
- InvocationDenoiseProgressEvent,
+ InvocationProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
QueueItemStatusChangedEvent,
diff --git a/invokeai/app/invocations/spandrel_image_to_image.py b/invokeai/app/invocations/spandrel_image_to_image.py
index ae4f48ef77c..2da8694ede7 100644
--- a/invokeai/app/invocations/spandrel_image_to_image.py
+++ b/invokeai/app/invocations/spandrel_image_to_image.py
@@ -1,3 +1,4 @@
+import functools
from typing import Callable
import numpy as np
@@ -61,6 +62,7 @@ def upscale_image(
tile_size: int,
spandrel_model: SpandrelImageToImageModel,
is_canceled: Callable[[], bool],
+ step_callback: Callable[[int, int], None],
) -> Image.Image:
# Compute the image tiles.
if tile_size > 0:
@@ -103,7 +105,12 @@ def upscale_image(
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on each tile.
- for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
+ pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")
+
+ # Update progress, starting with 0.
+ step_callback(0, pbar.total)
+
+ for tile, scaled_tile in pbar:
# Exit early if the invocation has been canceled.
if is_canceled():
raise CanceledException
@@ -136,6 +143,8 @@ def upscale_image(
:,
] = output_tile[top_overlap:, left_overlap:, :]
+ step_callback(pbar.n + 1, pbar.total)
+
# Convert the output tensor to a PIL image.
np_image = output_tensor.detach().numpy().astype(np.uint8)
pil_image = Image.fromarray(np_image)
@@ -151,12 +160,20 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
+ def step_callback(step: int, total_steps: int) -> None:
+ context.util.signal_progress(
+ message=f"Processing image (tile {step}/{total_steps})",
+ percentage=step / total_steps,
+ )
+
# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Upscale the image
- pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
+ pil_image = self.upscale_image(
+ image, self.tile_size, spandrel_model, context.util.is_canceled, step_callback
+ )
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
@@ -197,12 +214,27 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
target_width = int(image.width * self.scale)
target_height = int(image.height * self.scale)
+ def step_callback(iteration: int, step: int, total_steps: int) -> None:
+ context.util.signal_progress(
+ message=self._get_progress_message(iteration, step, total_steps),
+ percentage=step / total_steps,
+ )
+
# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
+ iteration = 1
+ context.util.signal_progress(self._get_progress_message(iteration))
+
# First pass of upscaling. Note: `pil_image` will be mutated.
- pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
+ pil_image = self.upscale_image(
+ image,
+ self.tile_size,
+ spandrel_model,
+ context.util.is_canceled,
+ functools.partial(step_callback, iteration),
+ )
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
@@ -211,16 +243,22 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
if is_upscale_model:
# This is an upscale model, so we should keep upscaling until we reach the target size.
- iterations = 1
while pil_image.width < target_width or pil_image.height < target_height:
- pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
- iterations += 1
+ iteration += 1
+ context.util.signal_progress(self._get_progress_message(iteration))
+ pil_image = self.upscale_image(
+ pil_image,
+ self.tile_size,
+ spandrel_model,
+ context.util.is_canceled,
+ functools.partial(step_callback, iteration),
+ )
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
# we should never reach this limit.
- if iterations >= 5:
+ if iteration >= 5:
context.logger.warning(
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
)
@@ -251,3 +289,10 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
+
+ @classmethod
+ def _get_progress_message(cls, iteration: int, step: int | None = None, total_steps: int | None = None) -> str:
+ if step is not None and total_steps is not None:
+ return f"Processing image (iteration {iteration}, tile {step}/{total_steps})"
+
+ return f"Processing image (iteration {iteration})"
diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py
index bb578c23e8c..681386e8770 100644
--- a/invokeai/app/services/events/events_base.py
+++ b/invokeai/app/services/events/events_base.py
@@ -15,8 +15,8 @@
DownloadStartedEvent,
EventBase,
InvocationCompleteEvent,
- InvocationDenoiseProgressEvent,
InvocationErrorEvent,
+ InvocationProgressEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
@@ -30,13 +30,12 @@
QueueClearedEvent,
QueueItemStatusChangedEvent,
)
-from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
+from invokeai.app.services.session_processor.session_processor_common import ProgressImage
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
- from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
@@ -58,15 +57,16 @@ def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "B
"""Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
- def emit_invocation_denoise_progress(
+ def emit_invocation_progress(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
- intermediate_state: PipelineIntermediateState,
- progress_image: "ProgressImage",
+ message: str,
+ percentage: float | None = None,
+ image: ProgressImage | None = None,
) -> None:
- """Emitted at each step during denoising of an invocation."""
- self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
+ """Emitted at each step during an invocation"""
+ self.dispatch(InvocationProgressEvent.build(queue_item, invocation, message, percentage, image))
def emit_invocation_complete(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py
index c6a867fb081..ad84773d9c5 100644
--- a/invokeai/app/services/events/events_common.py
+++ b/invokeai/app/services/events/events_common.py
@@ -1,4 +1,3 @@
-from math import floor
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
@@ -16,7 +15,6 @@
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
-from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
@@ -121,28 +119,28 @@ def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "Invo
@payload_schema.register
-class InvocationDenoiseProgressEvent(InvocationEventBase):
- """Event model for invocation_denoise_progress"""
+class InvocationProgressEvent(InvocationEventBase):
+ """Event model for invocation_progress"""
- __event_name__ = "invocation_denoise_progress"
+ __event_name__ = "invocation_progress"
- progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
- step: int = Field(description="The current step of the invocation")
- total_steps: int = Field(description="The total number of steps in the invocation")
- order: int = Field(description="The order of the invocation in the session")
- percentage: float = Field(description="The percentage of completion of the invocation")
+ message: str = Field(description="A message to display")
+ percentage: float | None = Field(
+ default=None, ge=0, le=1, description="The percentage of the progress (omit to indicate indeterminate progress)"
+ )
+ image: ProgressImage | None = Field(
+ default=None, description="An image representing the current state of the progress"
+ )
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
- intermediate_state: PipelineIntermediateState,
- progress_image: ProgressImage,
- ) -> "InvocationDenoiseProgressEvent":
- step = intermediate_state.step
- total_steps = intermediate_state.total_steps
- order = intermediate_state.order
+ message: str,
+ percentage: float | None = None,
+ image: ProgressImage | None = None,
+ ) -> "InvocationProgressEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
@@ -150,23 +148,11 @@ def build(
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
- progress_image=progress_image,
- step=step,
- total_steps=total_steps,
- order=order,
- percentage=cls.calc_percentage(step, total_steps, order),
+ percentage=percentage,
+ image=image,
+ message=message,
)
- @staticmethod
- def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
- """Calculate the percentage of completion of denoising."""
- if total_steps == 0:
- return 0.0
- if scheduler_order == 2:
- return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
- # order == 1
- return (step + 1 + 1) / (total_steps + 1)
-
@payload_schema.register
class InvocationCompleteEvent(InvocationEventBase):
diff --git a/invokeai/app/services/session_processor/session_processor_common.py b/invokeai/app/services/session_processor/session_processor_common.py
index 0ca51de517c..346f12d8bbc 100644
--- a/invokeai/app/services/session_processor/session_processor_common.py
+++ b/invokeai/app/services/session_processor/session_processor_common.py
@@ -1,5 +1,8 @@
+from PIL.Image import Image as PILImageType
from pydantic import BaseModel, Field
+from invokeai.backend.util.util import image_to_dataURL
+
class SessionProcessorStatus(BaseModel):
is_started: bool = Field(description="Whether the session processor is started")
@@ -15,6 +18,16 @@ class CanceledException(Exception):
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
- width: int = Field(description="The effective width of the image in pixels")
- height: int = Field(description="The effective height of the image in pixels")
+ width: int = Field(ge=1, description="The effective width of the image in pixels")
+ height: int = Field(ge=1, description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
+
+ @classmethod
+ def build(cls, image: PILImageType, size: tuple[int, int] | None = None) -> "ProgressImage":
+ """Build a ProgressImage from a PIL image"""
+
+ return cls(
+ width=size[0] if size else image.width,
+ height=size[1] if size else image.height,
+ dataURL=image_to_dataURL(image, image_format="JPEG"),
+ )
diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index 01662335e46..2e5b137c5f8 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -14,6 +14,7 @@
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
+from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import (
AnyModel,
@@ -550,13 +551,64 @@ def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_m
"""
stable_diffusion_step_callback(
- context_data=self._data,
+ signal_progress=self.signal_progress,
intermediate_state=intermediate_state,
base_model=base_model,
- events=self._services.events,
is_canceled=self.is_canceled,
)
+ def signal_progress(
+ self, message: str, percentage: float | None = None, image: ProgressImage | None = None
+ ) -> None:
+ """Signals the progress of some long-running invocation. The progress is displayed in the UI.
+
+ If you have an image to display, use `ProgressImage.build` to create the object.
+
+ If your progress image should be displayed at a different size, provide a tuple of `(width, height)` when
+ building the progress image.
+
+ For example, SD denoising progress images are 1/8 the size of the original image. In this case, the progress
+ image should be built like this to ensure it displays at the correct size:
+ ```py
+ progress_image = ProgressImage.build(image, (width * 8, height * 8))
+ ```
+
+ If your progress image is very large, consider downscaling it to reduce the payload size.
+
+ Example:
+ ```py
+ total_steps = 10
+ for i in range(total_steps):
+ # Do some iterative progressing
+ image = do_iterative_processing(image)
+
+ # Calculate the percentage
+ step = i + 1
+ percentage = step / total_steps
+
+ # Create a short, friendly message
+ message = f"Processing (step {step}/{total_steps})"
+
+ # Build the progress image
+ progress_image = ProgressImage.build(image)
+
+ # Send progress to the UI
+ context.util.signal_progress(message, percentage, progress_image)
+ ```
+
+ Args:
+ message: A message describing the current status.
+ percentage: The current percentage completion for the process. Omit for indeterminate progress.
+ image: An optional progress image to display.
+ """
+ self._services.events.emit_invocation_progress(
+ queue_item=self._data.queue_item,
+ invocation=self._data.invocation,
+ message=message,
+ percentage=percentage,
+ image=image,
+ )
+
class InvocationContext:
"""Provides access to various services and data for the current invocation.
diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py
index c0c101cd752..a5056c96171 100644
--- a/invokeai/app/util/step_callback.py
+++ b/invokeai/app/util/step_callback.py
@@ -1,4 +1,5 @@
-from typing import TYPE_CHECKING, Callable, Optional
+from math import floor
+from typing import Callable, Optional
import torch
from PIL import Image
@@ -6,11 +7,6 @@
from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
-from invokeai.backend.util.util import image_to_dataURL
-
-if TYPE_CHECKING:
- from invokeai.app.services.events.events_base import EventServiceBase
- from invokeai.app.services.shared.invocation_context import InvocationContextData
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
@@ -56,11 +52,25 @@ def sample_to_lowres_estimated_image(
return Image.fromarray(latents_ubyte.numpy())
+def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
+ """Calculate the percentage of completion of denoising."""
+
+ step = intermediate_state.step
+ total_steps = intermediate_state.total_steps
+ order = intermediate_state.order
+
+ if total_steps == 0:
+ return 0.0
+ if order == 2:
+ return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
+ # order == 1
+ return (step + 1 + 1) / (total_steps + 1)
+
+
def stable_diffusion_step_callback(
- context_data: "InvocationContextData",
+ signal_progress: Callable[[str, float | None, ProgressImage | None], None],
intermediate_state: PipelineIntermediateState,
base_model: BaseModelType,
- events: "EventServiceBase",
is_canceled: Callable[[], bool],
) -> None:
if is_canceled():
@@ -86,11 +96,10 @@ def stable_diffusion_step_callback(
width *= 8
height *= 8
- dataURL = image_to_dataURL(image, image_format="JPEG")
+ percentage = calc_percentage(intermediate_state)
- events.emit_invocation_denoise_progress(
- context_data.queue_item,
- context_data.invocation,
- intermediate_state,
- ProgressImage(dataURL=dataURL, width=width, height=height),
+ signal_progress(
+ "Denoising",
+ percentage,
+ ProgressImage.build(image=image, size=(width, height)),
)
diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts
index f0ea175aec7..b6f1d53a2c5 100644
--- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts
@@ -3,7 +3,7 @@ import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
-import { socketGeneratorProgress } from 'services/events/actions';
+import { socketInvocationProgress } from 'services/events/actions';
export const actionSanitizer = (action: A): A => {
if (isAnyGraphBuilt(action)) {
@@ -24,10 +24,10 @@ export const actionSanitizer = (action: A): A => {
};
}
- if (socketGeneratorProgress.match(action)) {
+ if (socketInvocationProgress.match(action)) {
const sanitized = deepClone(action);
- if (sanitized.payload.data.progress_image) {
- sanitized.payload.data.progress_image.dataURL = '