diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index a9132516a86..317ef15b170 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -17,6 +17,7 @@ from ..services.boards.boards_default import BoardService from ..services.config import InvokeAIAppConfig from ..services.download import DownloadQueueService +from ..services.events.events_fastapievents import FastAPIEventService from ..services.image_files.image_files_disk import DiskImageFileStorage from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage from ..services.images.images_default import ImageService @@ -32,7 +33,6 @@ from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue from ..services.urls.urls_default import LocalUrlService from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage -from .events import FastAPIEventService # TODO: is there a better way to achieve this? diff --git a/invokeai/app/api/events.py b/invokeai/app/api/events.py deleted file mode 100644 index 2ac07e6dfe3..00000000000 --- a/invokeai/app/api/events.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) - -import asyncio -import threading -from queue import Empty, Queue -from typing import Any - -from fastapi_events.dispatcher import dispatch - -from ..services.events.events_base import EventServiceBase - - -class FastAPIEventService(EventServiceBase): - event_handler_id: int - __queue: Queue - __stop_event: threading.Event - - def __init__(self, event_handler_id: int) -> None: - self.event_handler_id = event_handler_id - self.__queue = Queue() - self.__stop_event = threading.Event() - asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event)) - - super().__init__() - - def stop(self, *args, **kwargs): - self.__stop_event.set() - self.__queue.put(None) - - def dispatch(self, event_name: str, payload: Any) -> None: - self.__queue.put({"event_name": event_name, "payload": payload}) - - async def __dispatch_from_queue(self, stop_event: threading.Event): - """Get events on from the queue and dispatch them, from the correct thread""" - while not stop_event.is_set(): - try: - event = self.__queue.get(block=False) - if not event: # Probably stopping - continue - - dispatch( - event.get("event_name"), - payload=event.get("payload"), - middleware_id=self.event_handler_id, - ) - - except Empty: - await asyncio.sleep(0.1) - pass - - except asyncio.CancelledError as e: - raise e # Raise a proper error diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index e651e435591..163378053c5 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -1,11 +1,38 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +from typing import Any + from fastapi import FastAPI -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event +from pydantic import BaseModel from socketio import ASGIApp, AsyncServer -from ..services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + FastAPIEvent, + InvocationCompleteEvent, + InvocationDenoiseProgressEvent, + InvocationErrorEvent, + InvocationStartedEvent, + ModelEvent, + ModelInstalLCancelledEvent, + ModelInstallCompletedEvent, + ModelInstallDownloadProgressEvent, + ModelInstallErrorEvent, + ModelInstallStartedEvent, + ModelLoadCompleteEvent, + ModelLoadStartedEvent, + QueueClearedEvent, + QueueEvent, + QueueItemStatusChangedEvent, + SessionCanceledEvent, + SessionCompleteEvent, + SessionStartedEvent, + register_events, +) + + +class QueueSubscriptionEvent(BaseModel): + queue_id: str class SocketIO: @@ -19,23 +46,46 @@ def __init__(self, app: FastAPI): self.__sio.on("subscribe_queue", handler=self._handle_sub_queue) self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue) - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) - local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) - - async def _handle_queue_event(self, event: Event): - await self.__sio.emit( - event=event[1]["event"], - data=event[1]["data"], - room=event[1]["data"]["queue_id"], + + register_events( + [ + InvocationStartedEvent, + InvocationDenoiseProgressEvent, + InvocationCompleteEvent, + InvocationErrorEvent, + SessionStartedEvent, + SessionCompleteEvent, + SessionCanceledEvent, + QueueItemStatusChangedEvent, + BatchEnqueuedEvent, + QueueClearedEvent, + ], + self._handle_queue_event, ) - async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None: - if "queue_id" in data: - await self.__sio.enter_room(sid, data["queue_id"]) + register_events( + [ + ModelLoadStartedEvent, + ModelLoadCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallStartedEvent, + ModelInstallCompletedEvent, + ModelInstalLCancelledEvent, + ModelInstallErrorEvent, + ], + self._handle_model_event, + ) + + async def _handle_sub_queue(self, sid: str, data: Any) -> None: + await self.__sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id) + + async def _handle_unsub_queue(self, sid: str, data: Any) -> None: + await self.__sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id) - async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None: - if "queue_id" in data: - await self.__sio.leave_room(sid, data["queue_id"]) + async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]): + event_name, payload = event + await self.__sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id) - async def _handle_model_event(self, event: Event) -> None: - await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) + async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None: + event_name, payload = event + await self.__sio.emit(event=event_name, data=payload.model_dump()) diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index f6b08ddba66..cd5f3e3ac76 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -2,6 +2,9 @@ # which are imported/used before parse_args() is called will get the default config values instead of the # values from the command line or config file. import sys +from typing import cast + +from pydantic import BaseModel from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.version.invokeai_version import __version__ @@ -32,6 +35,7 @@ from fastapi.responses import HTMLResponse from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware + from fastapi_events.registry.payload_schema import registry as fastapi_events_registry from pydantic.json_schema import models_json_schema from torch.backends.mps import is_available as is_mps_available @@ -173,23 +177,16 @@ def custom_openapi() -> dict[str, Any]: invoker_schema["output"] = outputs_ref invoker_schema["class"] = "invocation" - # This code no longer seems to be necessary? - # Leave it here just in case - # - # from invokeai.backend.model_manager import get_model_config_formats - # formats = get_model_config_formats() - # for model_config_name, enum_set in formats.items(): - - # if model_config_name in openapi_schema["components"]["schemas"]: - # # print(f"Config with name {name} already defined") - # continue - - # openapi_schema["components"]["schemas"][model_config_name] = { - # "title": model_config_name, - # "description": "An enumeration.", - # "type": "string", - # "enum": [v.value for v in enum_set], - # } + # Add all pydantic event schemas registered with fastapi-events + for payload in fastapi_events_registry.data.values(): + json_schema = cast(BaseModel, payload).model_json_schema( + mode="serialization", ref_template="#/components/schemas/{model}" + ) + if "$defs" in json_schema: + for schema_key, schema in json_schema["$defs"].items(): + openapi_schema["components"]["schemas"][schema_key] = schema + del json_schema["$defs"] + openapi_schema["components"]["schemas"][payload.__name__] = json_schema app.openapi_schema = openapi_schema return app.openapi_schema diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 5355fe22987..323fc65e728 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -1,432 +1,12 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Any, Dict, List, Optional, Union - -from invokeai.app.services.session_processor.session_processor_common import ProgressImage -from invokeai.app.services.session_queue.session_queue_common import ( - BatchStatus, - EnqueueBatchResult, - SessionQueueItem, - SessionQueueStatus, -) -from invokeai.app.util.misc import get_timestamp -from invokeai.backend.model_manager import AnyModelConfig +from invokeai.app.services.events.events_common import AppEvent class EventServiceBase: - queue_event: str = "queue_event" - download_event: str = "download_event" - model_event: str = "model_event" """Basic event bus, to have an empty stand-in when not needed""" - def dispatch(self, event_name: str, payload: Any) -> None: + def dispatch(self, event: AppEvent) -> None: pass - - def __emit_queue_event(self, event_name: str, payload: dict) -> None: - """Queue events are emitted to a room with queue_id as the room name""" - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.queue_event, - payload={"event": event_name, "data": payload}, - ) - - def __emit_download_event(self, event_name: str, payload: dict) -> None: - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.download_event, - payload={"event": event_name, "data": payload}, - ) - - def __emit_model_event(self, event_name: str, payload: dict) -> None: - payload["timestamp"] = get_timestamp() - self.dispatch( - event_name=EventServiceBase.model_event, - payload={"event": event_name, "data": payload}, - ) - - # Define events here for every event in the system. - # This will make them easier to integrate until we find a schema generator. - def emit_generator_progress( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node_id: str, - source_node_id: str, - progress_image: Optional[ProgressImage], - step: int, - order: int, - total_steps: int, - ) -> None: - """Emitted when there is generation progress""" - self.__emit_queue_event( - event_name="generator_progress", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node_id": node_id, - "source_node_id": source_node_id, - "progress_image": progress_image.model_dump() if progress_image is not None else None, - "step": step, - "order": order, - "total_steps": total_steps, - }, - ) - - def emit_invocation_complete( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - result: dict, - node: dict, - source_node_id: str, - ) -> None: - """Emitted when an invocation has completed""" - self.__emit_queue_event( - event_name="invocation_complete", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - "result": result, - }, - ) - - def emit_invocation_error( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node: dict, - source_node_id: str, - error_type: str, - error: str, - ) -> None: - """Emitted when an invocation has completed""" - self.__emit_queue_event( - event_name="invocation_error", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - "error_type": error_type, - "error": error, - }, - ) - - def emit_invocation_started( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - node: dict, - source_node_id: str, - ) -> None: - """Emitted when an invocation has started""" - self.__emit_queue_event( - event_name="invocation_started", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "node": node, - "source_node_id": source_node_id, - }, - ) - - def emit_graph_execution_complete( - self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str - ) -> None: - """Emitted when a session has completed all invocations""" - self.__emit_queue_event( - event_name="graph_execution_state_complete", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - }, - ) - - def emit_model_load_started( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - model_config: AnyModelConfig, - ) -> None: - """Emitted when a model is requested""" - self.__emit_queue_event( - event_name="model_load_started", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "model_config": model_config.model_dump(), - }, - ) - - def emit_model_load_completed( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - model_config: AnyModelConfig, - ) -> None: - """Emitted when a model is correctly loaded (returns model info)""" - self.__emit_queue_event( - event_name="model_load_completed", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - "model_config": model_config.model_dump(), - }, - ) - - def emit_session_canceled( - self, - queue_id: str, - queue_item_id: int, - queue_batch_id: str, - graph_execution_state_id: str, - ) -> None: - """Emitted when a session is canceled""" - self.__emit_queue_event( - event_name="session_canceled", - payload={ - "queue_id": queue_id, - "queue_item_id": queue_item_id, - "queue_batch_id": queue_batch_id, - "graph_execution_state_id": graph_execution_state_id, - }, - ) - - def emit_queue_item_status_changed( - self, - session_queue_item: SessionQueueItem, - batch_status: BatchStatus, - queue_status: SessionQueueStatus, - ) -> None: - """Emitted when a queue item's status changes""" - self.__emit_queue_event( - event_name="queue_item_status_changed", - payload={ - "queue_id": queue_status.queue_id, - "queue_item": { - "queue_id": session_queue_item.queue_id, - "item_id": session_queue_item.item_id, - "status": session_queue_item.status, - "batch_id": session_queue_item.batch_id, - "session_id": session_queue_item.session_id, - "error": session_queue_item.error, - "created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None, - "updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None, - "started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None, - "completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None, - }, - "batch_status": batch_status.model_dump(), - "queue_status": queue_status.model_dump(), - }, - ) - - def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None: - """Emitted when a batch is enqueued""" - self.__emit_queue_event( - event_name="batch_enqueued", - payload={ - "queue_id": enqueue_result.queue_id, - "batch_id": enqueue_result.batch.batch_id, - "enqueued": enqueue_result.enqueued, - }, - ) - - def emit_queue_cleared(self, queue_id: str) -> None: - """Emitted when the queue is cleared""" - self.__emit_queue_event( - event_name="queue_cleared", - payload={"queue_id": queue_id}, - ) - - def emit_download_started(self, source: str, download_path: str) -> None: - """ - Emit when a download job is started. - - :param url: The downloaded url - """ - self.__emit_download_event( - event_name="download_started", - payload={"source": source, "download_path": download_path}, - ) - - def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None: - """ - Emit "download_progress" events at regular intervals during a download job. - - :param source: The downloaded source - :param download_path: The local downloaded file - :param current_bytes: Number of bytes downloaded so far - :param total_bytes: The size of the file being downloaded (if known) - """ - self.__emit_download_event( - event_name="download_progress", - payload={ - "source": source, - "download_path": download_path, - "current_bytes": current_bytes, - "total_bytes": total_bytes, - }, - ) - - def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None: - """ - Emit a "download_complete" event at the end of a successful download. - - :param source: Source URL - :param download_path: Path to the locally downloaded file - :param total_bytes: The size of the downloaded file - """ - self.__emit_download_event( - event_name="download_complete", - payload={ - "source": source, - "download_path": download_path, - "total_bytes": total_bytes, - }, - ) - - def emit_download_cancelled(self, source: str) -> None: - """Emit a "download_cancelled" event in the event that the download was cancelled by user.""" - self.__emit_download_event( - event_name="download_cancelled", - payload={ - "source": source, - }, - ) - - def emit_download_error(self, source: str, error_type: str, error: str) -> None: - """ - Emit a "download_error" event when an download job encounters an exception. - - :param source: Source URL - :param error_type: The name of the exception that raised the error - :param error: The traceback from this error - """ - self.__emit_download_event( - event_name="download_error", - payload={ - "source": source, - "error_type": error_type, - "error": error, - }, - ) - - def emit_model_install_downloading( - self, - source: str, - local_path: str, - bytes: int, - total_bytes: int, - parts: List[Dict[str, Union[str, int]]], - ) -> None: - """ - Emit at intervals while the install job is in progress (remote models only). - - :param source: Source of the model - :param local_path: Where model is downloading to - :param parts: Progress of downloading URLs that comprise the model, if any. - :param bytes: Number of bytes downloaded so far. - :param total_bytes: Total size of download, including all files. - This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes". - """ - self.__emit_model_event( - event_name="model_install_downloading", - payload={ - "source": source, - "local_path": local_path, - "bytes": bytes, - "total_bytes": total_bytes, - "parts": parts, - }, - ) - - def emit_model_install_running(self, source: str) -> None: - """ - Emit once when an install job becomes active. - - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_running", - payload={"source": source}, - ) - - def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None: - """ - Emit when an install job is completed successfully. - - :param source: Source of the model; local path, repo_id or url - :param key: Model config record key - :param total_bytes: Size of the model (may be None for installation of a local path) - """ - self.__emit_model_event( - event_name="model_install_completed", - payload={ - "source": source, - "total_bytes": total_bytes, - "key": key, - }, - ) - - def emit_model_install_cancelled(self, source: str) -> None: - """ - Emit when an install job is cancelled. - - :param source: Source of the model; local path, repo_id or url - """ - self.__emit_model_event( - event_name="model_install_cancelled", - payload={"source": source}, - ) - - def emit_model_install_error( - self, - source: str, - error_type: str, - error: str, - ) -> None: - """ - Emit when an install job encounters an exception. - - :param source: Source of the model - :param error_type: The name of the exception - :param error: A text description of the exception - """ - self.__emit_model_event( - event_name="model_install_error", - payload={ - "source": source, - "error_type": error_type, - "error": error, - }, - ) diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py new file mode 100644 index 00000000000..3779d086a9c --- /dev/null +++ b/invokeai/app/services/events/events_common.py @@ -0,0 +1,532 @@ +from abc import ABC +from enum import Enum +from typing import Any, ClassVar, Coroutine, Optional, Protocol, TypeAlias, TypeVar + +from fastapi_events.handlers.local import local_handler +from fastapi_events.registry.payload_schema import registry as payload_schema +from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny + +from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput +from invokeai.app.services.session_processor.session_processor_common import ProgressImage +from invokeai.app.services.session_queue.session_queue_common import ( + QUEUE_ITEM_STATUS, + BatchStatus, + EnqueueBatchResult, + SessionQueueItem, + SessionQueueStatus, +) +from invokeai.app.util.misc import get_timestamp +from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType + + +class EventType(str, Enum): + QUEUE = "queue" + MODEL = "model" + DOWNLOAD = "download" + + +class AppEvent(BaseModel, ABC): + """Base class for all events. All events must inherit from this class. + + Events must define the following class attributes: + - `__event_name__: str`: The name of the event + - `__event_type__: EventType`: The type of the event + + All other attributes should be defined as normal for a pydantic model. + + A timestamp is automatically added to the event when it is created. + """ + + __event_name__: ClassVar[str] = ... # pyright: ignore [reportAssignmentType] + __event_type__: ClassVar[EventType] = ... # pyright: ignore [reportAssignmentType] + + timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp) + + def __init_subclass__(cls, **kwargs: ConfigDict): + for required_attr in ("__event_name__", "__event_type__"): + if getattr(cls, required_attr) is ...: + raise TypeError(f"{cls.__name__} must define {required_attr}") + + model_config = ConfigDict(json_schema_serialization_defaults_required=True) + + +TEvent = TypeVar("TEvent", bound=AppEvent) + +FastAPIEvent: TypeAlias = tuple[str, TEvent] +""" +A tuple representing a `fastapi-events` event, with the event name and payload. +Provide a generic type to `TEvent` to specify the payload type. +""" + + +class FastAPIEventFunc(Protocol): + def __call__(self, event: FastAPIEvent[Any]) -> Optional[Coroutine[Any, Any, None]]: + ... + + +def register_events(events: list[type[AppEvent]], func: FastAPIEventFunc) -> None: + """Register a function to handle a list of events. + + :param events: A list of event classes to handle + :param func: The function to handle the events + """ + for event in events: + local_handler.register(event_name=event.__event_name__, _func=func) + + +class QueueEvent(AppEvent, ABC): + """Base class for queue events""" + + __event_type__ = EventType.QUEUE + __event_name__ = "queue_event" + + queue_id: str = Field(description="The ID of the queue") + + +class QueueItemEvent(QueueEvent, ABC): + """Base class for queue item events""" + + __event_name__ = "queue_item_event" + + item_id: int = Field(description="The ID of the queue item") + batch_id: str = Field(description="The ID of the queue batch") + + +class SessionEvent(QueueItemEvent, ABC): + """Base class for session (aka graph execution state) events""" + + __event_name__ = "session_event" + + session_id: str = Field(description="The ID of the session (aka graph execution state)") + + +class InvocationEvent(SessionEvent, ABC): + """Base class for invocation events""" + + __event_name__ = "invocation_event" + + queue_id: str = Field(description="The ID of the queue") + item_id: int = Field(description="The ID of the queue item") + batch_id: str = Field(description="The ID of the queue batch") + session_id: str = Field(description="The ID of the session (aka graph execution state)") + invocation_id: str = Field(description="The ID of the invocation") + invocation_source_id: str = Field(description="The ID of the prepared invocation's source node") + invocation_type: str = Field(description="The type of invocation") + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationStartedEvent(InvocationEvent): + """Emitted when an invocation is started""" + + __event_name__ = "invocation_started" + + @classmethod + def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation_id=invocation.id, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationDenoiseProgressEvent(InvocationEvent): + """Emitted at each step during denoising of an invocation.""" + + __event_name__ = "invocation_denoise_progress" + + progress_image: ProgressImage = Field(description="The progress image sent at each step during processing") + step: int = Field(description="The current step of the invocation") + total_steps: int = Field(description="The total number of steps in the invocation") + + @classmethod + def build( + cls, + queue_item: SessionQueueItem, + invocation: BaseInvocation, + step: int, + total_steps: int, + progress_image: ProgressImage, + ) -> "InvocationDenoiseProgressEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation_id=invocation.id, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + progress_image=progress_image, + step=step, + total_steps=total_steps, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationCompleteEvent(InvocationEvent): + """Emitted when an invocation is complete""" + + __event_name__ = "invocation_complete" + + result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput + ) -> "InvocationCompleteEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation_id=invocation.id, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + result=result, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class InvocationErrorEvent(InvocationEvent): + """Emitted when an invocation encounters an error""" + + __event_name__ = "invocation_error" + + error_type: str = Field(description="The type of error") + error: str = Field(description="The error message") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, invocation: BaseInvocation, error_type: str, error: str + ) -> "InvocationErrorEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + invocation_id=invocation.id, + invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id], + invocation_type=invocation.get_type(), + error_type=error_type, + error=error, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class SessionStartedEvent(SessionEvent): + """Emitted when a session has started""" + + __event_name__ = "session_started" + + @classmethod + def build(cls, queue_item: SessionQueueItem) -> "SessionStartedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class SessionCompleteEvent(SessionEvent): + """Emitted when a session has completed all invocations""" + + __event_name__ = "session_complete" + + @classmethod + def build(cls, queue_item: SessionQueueItem) -> "SessionCompleteEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class SessionCanceledEvent(SessionEvent): + """Emitted when a session is canceled""" + + __event_name__ = "session_canceled" + + @classmethod + def build(cls, queue_item: SessionQueueItem) -> "SessionCanceledEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + session_id=queue_item.session_id, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class QueueItemStatusChangedEvent(QueueItemEvent): + """Emitted when a queue item's status changes""" + + __event_name__ = "queue_item_status_changed" + + status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item") + error: Optional[str] = Field(default=None, description="The error message, if any") + created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created") + updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated") + started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started") + completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed") + batch_status: BatchStatus = Field(description="The status of the batch") + queue_status: SessionQueueStatus = Field(description="The status of the queue") + + @classmethod + def build( + cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus + ) -> "QueueItemStatusChangedEvent": + return cls( + queue_id=queue_item.queue_id, + item_id=queue_item.item_id, + batch_id=queue_item.batch_id, + status=queue_item.status, + error=queue_item.error, + created_at=str(queue_item.created_at) if queue_item.created_at else None, + updated_at=str(queue_item.updated_at) if queue_item.updated_at else None, + started_at=str(queue_item.started_at) if queue_item.started_at else None, + completed_at=str(queue_item.completed_at) if queue_item.completed_at else None, + batch_status=batch_status, + queue_status=queue_status, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class BatchEnqueuedEvent(QueueEvent): + """Emitted when a batch is enqueued""" + + __event_name__ = "batch_enqueued" + + batch_id: str = Field(description="The ID of the batch") + enqueued: int = Field(description="The number of invocations enqueued") + requested: int = Field( + description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)" + ) + priority: int = Field(description="The priority of the batch") + + @classmethod + def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": + return cls( + queue_id=enqueue_result.queue_id, + batch_id=enqueue_result.batch.batch_id, + enqueued=enqueue_result.enqueued, + requested=enqueue_result.requested, + priority=enqueue_result.priority, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class QueueClearedEvent(QueueEvent): + """Emitted when a queue is cleared""" + + __event_name__ = "queue_cleared" + + @classmethod + def build(cls, queue_id: str) -> "QueueClearedEvent": + return cls(queue_id=queue_id) + + +class DownloadEvent(AppEvent, ABC): + """Base class for events associated with a download""" + + __event_type__ = EventType.DOWNLOAD + __event_name__ = "download_event" + + source: str = Field(description="The source of the download") + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadStartedEvent(DownloadEvent): + """Emitted when a download is started""" + + __event_name__ = "download_started" + + download_path: str = Field(description="The local path where the download is saved") + + @classmethod + def build(cls, source: str, download_path: str) -> "DownloadStartedEvent": + return cls(source=source, download_path=download_path) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadProgressEvent(DownloadEvent): + """Emitted at intervals during a download""" + + __event_name__ = "download_progress" + + download_path: str = Field(description="The local path where the download is saved") + current_bytes: int = Field(description="The number of bytes downloaded so far") + total_bytes: int = Field(description="The total number of bytes to be downloaded") + + @classmethod + def build(cls, source: str, download_path: str, current_bytes: int, total_bytes: int) -> "DownloadProgressEvent": + return cls(source=source, download_path=download_path, current_bytes=current_bytes, total_bytes=total_bytes) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadCompleteEvent(DownloadEvent): + """Emitted when a download is completed""" + + __event_name__ = "download_complete" + + download_path: str = Field(description="The local path where the download is saved") + total_bytes: int = Field(description="The total number of bytes downloaded") + + @classmethod + def build(cls, source: str, download_path: str, total_bytes: int) -> "DownloadCompleteEvent": + return cls(source=source, download_path=download_path, total_bytes=total_bytes) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadCancelledEvent(DownloadEvent): + """Emitted when a download is cancelled""" + + __event_name__ = "download_cancelled" + + @classmethod + def build(cls, source: str) -> "DownloadCancelledEvent": + return cls(source=source) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class DownloadErrorEvent(DownloadEvent): + """Emitted when a download encounters an error""" + + __event_name__ = "download_error" + + error_type: str = Field(description="The type of error") + error: str = Field(description="The error message") + + @classmethod + def build(cls, source: str, error_type: str, error: str) -> "DownloadErrorEvent": + return cls(source=source, error_type=error_type, error=error) + + +class ModelEvent(AppEvent, ABC): + """Base class for events associated with a model""" + + __event_type__ = EventType.MODEL + __event_name__ = "model_event" + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelLoadStartedEvent(ModelEvent): + """Emitted when a model is requested""" + + __event_name__ = "model_load_started" + + config: AnyModelConfig = Field(description="The model's config") + submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") + + @classmethod + def build(cls, config: AnyModelConfig, submodel_type: SubModelType) -> "ModelLoadStartedEvent": + return cls(config=config, submodel_type=submodel_type) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelLoadCompleteEvent(ModelEvent): + """Emitted when a model is requested""" + + __event_name__ = "model_load_complete" + + config: AnyModelConfig = Field(description="The model's config") + submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") + + @classmethod + def build(cls, config: AnyModelConfig, submodel_type: SubModelType) -> "ModelLoadCompleteEvent": + return cls(config=config, submodel_type=submodel_type) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallDownloadProgressEvent(ModelEvent): + """Emitted at intervals while the install job is in progress (remote models only).""" + + __event_name__ = "model_install_download_progress" + + source: str = Field(description="Source of the model; local path, repo_id or url") + local_path: str = Field(description="Where model is downloading to") + bytes: int = Field(description="Number of bytes downloaded so far") + total_bytes: int = Field(description="Total size of download, including all files") + parts: list[dict[str, int | str]] = Field( + description="Progress of downloading URLs that comprise the model, if any" + ) + + @classmethod + def build( + cls, + source: str, + local_path: str, + bytes: int, + total_bytes: int, + parts: list[dict[str, int | str]], + ) -> "ModelInstallDownloadProgressEvent": + return cls( + source=source, + local_path=local_path, + bytes=bytes, + total_bytes=total_bytes, + parts=parts, + ) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallStartedEvent(ModelEvent): + """Emitted once when an install job becomes active.""" + + __event_name__ = "model_install_started" + + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, source: str) -> "ModelInstallStartedEvent": + return cls(source=source) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallCompletedEvent(ModelEvent): + """Emitted when an install job is completed successfully.""" + + __event_name__ = "model_install_completed" + + source: str = Field(description="Source of the model; local path, repo_id or url") + key: str = Field(description="Model config record key") + total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)") + + @classmethod + def build(cls, source: str, key: str, total_bytes: Optional[int]) -> "ModelInstallCompletedEvent": + return cls(source=source, key=key, total_bytes=total_bytes) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstalLCancelledEvent(ModelEvent): + """Emitted when an install job is cancelled.""" + + __event_name__ = "model_install_cancelled" + + source: str = Field(description="Source of the model; local path, repo_id or url") + + @classmethod + def build(cls, source: str) -> "ModelInstalLCancelledEvent": + return cls(source=source) + + +@payload_schema.register # pyright: ignore [reportUnknownMemberType] +class ModelInstallErrorEvent(ModelEvent): + """Emitted when an install job encounters an exception.""" + + __event_name__ = "model_install_error" + + source: str = Field(description="Source of the model; local path, repo_id or url") + error_type: str = Field(description="The name of the exception") + error: str = Field(description="A text description of the exception") + + @classmethod + def build(cls, source: str, error_type: str, error: str) -> "ModelInstallErrorEvent": + return cls(source=source, error_type=error_type, error=error) diff --git a/invokeai/app/services/events/events_fastapievents.py b/invokeai/app/services/events/events_fastapievents.py new file mode 100644 index 00000000000..a722a109f00 --- /dev/null +++ b/invokeai/app/services/events/events_fastapievents.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) + +import asyncio +import threading +from queue import Empty, Queue + +from fastapi_events.dispatcher import dispatch + +from invokeai.app.services.events.events_common import AppEvent + +from .events_base import EventServiceBase + + +class FastAPIEventService(EventServiceBase): + def __init__(self, event_handler_id: int) -> None: + self.event_handler_id = event_handler_id + self._queue = Queue[AppEvent | None]() + self._stop_event = threading.Event() + asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event)) + + super().__init__() + + def stop(self, *args, **kwargs): + self._stop_event.set() + self._queue.put(None) + + def dispatch(self, event: AppEvent) -> None: + self._queue.put(event) + + async def _dispatch_from_queue(self, stop_event: threading.Event): + """Get events on from the queue and dispatch them, from the correct thread""" + while not stop_event.is_set(): + try: + event = self._queue.get(block=False) + if not event: # Probably stopping + continue + print(event) + dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False) + + except Empty: + await asyncio.sleep(0.1) + pass + + except asyncio.CancelledError as e: + raise e # Raise a proper error diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index cc80333e932..9d75aafde12 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -4,7 +4,6 @@ from abc import ABC, abstractmethod from typing import Optional -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase @@ -15,18 +14,12 @@ class ModelLoadServiceBase(ABC): """Wrapper around AnyModelLoader.""" @abstractmethod - def load_model( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. - :param context_data: Invocation context data used for event reporting """ @property diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 3ff7898c0e4..3b925fba689 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -4,8 +4,8 @@ from typing import Optional, Type from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.events.events_common import ModelLoadCompleteEvent, ModelLoadStartedEvent from invokeai.app.services.invoker import Invoker -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager.load import ( LoadedModel, @@ -51,24 +51,15 @@ def convert_cache(self) -> ModelConvertCacheBase: """Return the checkpoint convert cache used by this loader.""" return self._convert_cache - def load_model( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ Given a model's configuration, load it and return the LoadedModel object. :param model_config: Model configuration record (as returned by ModelRecordBase.get_model()) :param submodel: For main (pipeline models), the submodel to fetch. - :param context: Invocation context used for event reporting """ - if context_data: - self._emit_load_event( - context_data=context_data, - model_config=model_config, - ) + + self._invoker.services.events.dispatch(ModelLoadStartedEvent(config=model_config, submodel_type=submodel_type)) implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore loaded_model: LoadedModel = implementation( @@ -78,36 +69,6 @@ def load_model( convert_cache=self._convert_cache, ).load_model(model_config, submodel_type) - if context_data: - self._emit_load_event( - context_data=context_data, - model_config=model_config, - loaded=True, - ) - return loaded_model - - def _emit_load_event( - self, - context_data: InvocationContextData, - model_config: AnyModelConfig, - loaded: Optional[bool] = False, - ) -> None: - if not self._invoker: - return + self._invoker.services.events.dispatch(ModelLoadCompleteEvent(config=model_config, submodel_type=submodel_type)) - if not loaded: - self._invoker.services.events.emit_model_load_started( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - model_config=model_config, - ) - else: - self._invoker.services.events.emit_model_load_completed( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - model_config=model_config, - ) + return loaded_model diff --git a/invokeai/app/services/model_manager/model_manager_base.py b/invokeai/app/services/model_manager/model_manager_base.py index c25aa6fb47c..0ad53982701 100644 --- a/invokeai/app/services/model_manager/model_manager_base.py +++ b/invokeai/app/services/model_manager/model_manager_base.py @@ -6,7 +6,6 @@ from typing_extensions import Self from invokeai.app.services.invoker import Invoker -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager.load.load_base import LoadedModel @@ -72,29 +71,16 @@ def stop(self, invoker: Invoker) -> None: @abstractmethod def load_model_by_config( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, + self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None ) -> LoadedModel: pass @abstractmethod - def load_model_by_key( - self, - key: str, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: pass @abstractmethod def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, + self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None ) -> LoadedModel: pass diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index d029f9e0339..217648f089f 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -6,7 +6,6 @@ from typing_extensions import Self from invokeai.app.services.invoker import Invoker -from invokeai.app.services.shared.invocation_context import InvocationContextData from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry from invokeai.backend.util.logging import InvokeAILogger @@ -63,29 +62,16 @@ def stop(self, invoker: Invoker) -> None: service.stop(invoker) def load_model_by_config( - self, - model_config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, + self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None ) -> LoadedModel: - return self.load.load_model(model_config, submodel_type, context_data) + return self.load.load_model(model_config, submodel_type) - def load_model_by_key( - self, - key: str, - submodel_type: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, - ) -> LoadedModel: + def load_model_by_key(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel: config = self.store.get_model(key) - return self.load.load_model(config, submodel_type, context_data) + return self.load.load_model(config, submodel_type) def load_model_by_attr( - self, - model_name: str, - base_model: BaseModelType, - model_type: ModelType, - submodel: Optional[SubModelType] = None, - context_data: Optional[InvocationContextData] = None, + self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None ) -> LoadedModel: """ Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object. @@ -110,7 +96,7 @@ def load_model_by_attr( elif len(configs) > 1: raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.") else: - return self.load.load_model(configs[0], submodel, context_data) + return self.load.load_model(configs[0], submodel) @classmethod def build_model_manager( diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index c0b98220c87..4d98dcd79b3 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -4,11 +4,20 @@ from threading import Event as ThreadEvent from typing import Optional -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event as FastAPIEvent - from invokeai.app.invocations.baseinvocation import BaseInvocation -from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + FastAPIEvent, + InvocationCompleteEvent, + InvocationErrorEvent, + InvocationStartedEvent, + QueueClearedEvent, + QueueEvent, + SessionCanceledEvent, + SessionCompleteEvent, + SessionStartedEvent, + register_events, +) from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem @@ -31,8 +40,6 @@ def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = self._poll_now_event = ThreadEvent() self._cancel_event = ThreadEvent() - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event) - self._thread_limit = thread_limit self._thread_semaphore = BoundedSemaphore(thread_limit) self._polling_interval = polling_interval @@ -49,6 +56,8 @@ def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = else None ) + register_events([SessionCanceledEvent, QueueClearedEvent, BatchEnqueuedEvent], self._on_queue_event) + self._thread = Thread( name="session_processor", target=self._process, @@ -67,14 +76,13 @@ def stop(self, *args, **kwargs) -> None: def _poll_now(self) -> None: self._poll_now_event.set() - async def _on_queue_event(self, event: FastAPIEvent) -> None: - event_name = event[1]["event"] - - if event_name == "session_canceled" or event_name == "queue_cleared": + async def _on_queue_event(self, event: FastAPIEvent[QueueEvent]) -> None: + _event_name, payload = event + if isinstance(payload, (SessionCanceledEvent, QueueClearedEvent)): # These both mean we should cancel the current session. self._cancel_event.set() self._poll_now() - elif event_name == "batch_enqueued": + elif isinstance(payload, BatchEnqueuedEvent): self._poll_now() def resume(self) -> SessionProcessorStatus: @@ -114,6 +122,8 @@ def _process( # Get the next session to process self._queue_item = self._invoker.services.session_queue.dequeue() if self._queue_item is not None and resume_event.is_set(): + # Dispatch session started event + self._invoker.services.events.dispatch(SessionStartedEvent.build(queue_item=self._queue_item)) self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") cancel_event.clear() @@ -129,14 +139,9 @@ def _process( # get the source node id to provide to clients (the prepared node id is not as useful) source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id] - # Send starting event - self._invoker.services.events.emit_invocation_started( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session_id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, + # Dispatch invocation started event + self._invoker.services.events.dispatch( + InvocationStartedEvent.build(queue_item=self._queue_item, invocation=self._invocation) ) # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph @@ -164,19 +169,15 @@ def _process( # Save outputs and history self._queue_item.session.complete(self._invocation.id, outputs) - # Send complete event - self._invoker.services.events.emit_invocation_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - result=outputs.model_dump(), + # Dispatch invocation complete event + self._invoker.services.events.dispatch( + InvocationCompleteEvent.build( + queue_item=self._queue_item, invocation=self._invocation, result=outputs + ) ) except KeyboardInterrupt: - # TODO(MM2): Create an event for this + # TODO(MM2): I don't think this is ever raised... pass except CanceledException: @@ -201,27 +202,22 @@ def _process( f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}" ) - # Send error event - self._invoker.services.events.emit_invocation_error( - queue_batch_id=self._queue_item.session_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - error_type=e.__class__.__name__, - error=error, + # Dispatch invocation error event + self._invoker.services.events.dispatch( + InvocationErrorEvent.build( + queue_item=self._queue_item, + invocation=self._invocation, + error_type=e.__class__.__name__, + error=error, + ) ) pass # The session is complete if the all invocations are complete or there was an error if self._queue_item.session.is_complete() or cancel_event.is_set(): - # Send complete event - self._invoker.services.events.emit_graph_execution_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, + # Dispatch session complete event + self._invoker.services.events.dispatch( + SessionCompleteEvent.build(queue_item=self._queue_item) ) # If we are profiling, stop the profiler and dump the profile & stats if self._profiler: diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 7af9f0e08cd..02b7b92b2f2 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -2,10 +2,16 @@ import threading from typing import Optional, Union, cast -from fastapi_events.handlers.local import local_handler -from fastapi_events.typing import Event as FastAPIEvent - -from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + BatchEnqueuedEvent, + FastAPIEvent, + InvocationErrorEvent, + QueueClearedEvent, + QueueItemStatusChangedEvent, + SessionCanceledEvent, + SessionCompleteEvent, + register_events, +) from invokeai.app.services.invoker import Invoker from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase from invokeai.app.services.session_queue.session_queue_common import ( @@ -41,7 +47,11 @@ def start(self, invoker: Invoker) -> None: self.__invoker = invoker self._set_in_progress_to_canceled() prune_result = self.prune(DEFAULT_QUEUE_ID) - local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event) + + register_events(events=[InvocationErrorEvent], func=self._handle_error_event) + register_events(events=[SessionCompleteEvent], func=self._handle_complete_event) + register_events(events=[SessionCanceledEvent], func=self._handle_cancel_event) + if prune_result.deleted > 0: self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items") @@ -51,51 +61,35 @@ def __init__(self, db: SqliteDatabase) -> None: self.__conn = db.conn self.__cursor = self.__conn.cursor() - def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool: - return event[1]["event"] in match_in - - async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent: - event_name = event[1]["event"] - - # This was a match statement, but match is not supported on python 3.9 - if event_name == "graph_execution_state_complete": - await self._handle_complete_event(event) - elif event_name == "invocation_error": - await self._handle_error_event(event) - elif event_name == "session_canceled": - await self._handle_cancel_event(event) - return event - - async def _handle_complete_event(self, event: FastAPIEvent) -> None: + async def _handle_complete_event(self, event: FastAPIEvent[SessionCompleteEvent]) -> None: try: - item_id = event[1]["data"]["queue_item_id"] # When a queue item has an error, we get an error event, then a completed event. # Mark the queue item completed only if it isn't already marked completed, e.g. # by a previously-handled error event. - queue_item = self.get_queue_item(item_id) + _event_name, payload = event + + queue_item = self.get_queue_item(payload.item_id) if queue_item.status not in ["completed", "failed", "canceled"]: - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed") + self._set_queue_item_status(item_id=payload.item_id, status="completed") except SessionQueueItemNotFoundError: - return + pass - async def _handle_error_event(self, event: FastAPIEvent) -> None: + async def _handle_error_event(self, event: FastAPIEvent[InvocationErrorEvent]) -> None: try: - item_id = event[1]["data"]["queue_item_id"] - error = event[1]["data"]["error"] - queue_item = self.get_queue_item(item_id) + _event_name, payload = event # always set to failed if have an error, even if previously the item was marked completed or canceled - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error) + self._set_queue_item_status(item_id=payload.item_id, status="failed", error=payload.error) except SessionQueueItemNotFoundError: - return + pass - async def _handle_cancel_event(self, event: FastAPIEvent) -> None: + async def _handle_cancel_event(self, event: FastAPIEvent[SessionCanceledEvent]) -> None: try: - item_id = event[1]["data"]["queue_item_id"] - queue_item = self.get_queue_item(item_id) + _event_name, payload = event + queue_item = self.get_queue_item(payload.item_id) if queue_item.status not in ["completed", "failed", "canceled"]: - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled") + self._set_queue_item_status(item_id=payload.item_id, status="canceled") except SessionQueueItemNotFoundError: - return + pass def _set_in_progress_to_canceled(self) -> None: """ @@ -190,7 +184,7 @@ def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBa batch=batch, priority=priority, ) - self.__invoker.services.events.emit_batch_enqueued(enqueue_result) + self.__invoker.services.events.dispatch(BatchEnqueuedEvent.build(enqueue_result)) return enqueue_result def dequeue(self) -> Optional[SessionQueueItem]: @@ -292,10 +286,8 @@ def _set_queue_item_status( queue_item = self.get_queue_item(item_id) batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_item.queue_id) - self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=queue_item, - batch_status=batch_status, - queue_status=queue_status, + self.__invoker.services.events.dispatch( + QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status) ) return queue_item @@ -384,7 +376,7 @@ def clear(self, queue_id: str) -> ClearResult: raise finally: self.__lock.release() - self.__invoker.services.events.emit_queue_cleared(queue_id) + self.__invoker.services.events.dispatch(QueueClearedEvent.build(queue_id)) return ClearResult(deleted=count) def prune(self, queue_id: str) -> PruneResult: @@ -429,12 +421,7 @@ def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> Sessio if queue_item.status not in ["canceled", "failed", "completed"]: status = "failed" if error is not None else "canceled" queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here - self.__invoker.services.events.emit_session_canceled( - queue_item_id=queue_item.item_id, - queue_id=queue_item.queue_id, - queue_batch_id=queue_item.batch_id, - graph_execution_state_id=queue_item.session_id, - ) + self.__invoker.services.events.dispatch(SessionCanceledEvent.build(queue_item)) return queue_item def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: @@ -470,18 +457,11 @@ def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBa ) self.__conn.commit() if current_queue_item is not None and current_queue_item.batch_id in batch_ids: - self.__invoker.services.events.emit_session_canceled( - queue_item_id=current_queue_item.item_id, - queue_id=current_queue_item.queue_id, - queue_batch_id=current_queue_item.batch_id, - graph_execution_state_id=current_queue_item.session_id, - ) + self.__invoker.services.events.dispatch(SessionCanceledEvent.build(current_queue_item)) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_id) - self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=current_queue_item, - batch_status=batch_status, - queue_status=queue_status, + self.__invoker.services.events.dispatch( + QueueItemStatusChangedEvent.build(current_queue_item, batch_status, queue_status) ) except Exception: self.__conn.rollback() @@ -521,18 +501,11 @@ def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: ) self.__conn.commit() if current_queue_item is not None and current_queue_item.queue_id == queue_id: - self.__invoker.services.events.emit_session_canceled( - queue_item_id=current_queue_item.item_id, - queue_id=current_queue_item.queue_id, - queue_batch_id=current_queue_item.batch_id, - graph_execution_state_id=current_queue_item.session_id, - ) + self.__invoker.services.events.dispatch(SessionCanceledEvent.build(current_queue_item)) batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_id) - self.__invoker.services.events.emit_queue_item_status_changed( - session_queue_item=current_queue_item, - batch_status=batch_status, - queue_status=queue_status, + self.__invoker.services.events.dispatch( + QueueItemStatusChangedEvent.build(current_queue_item, batch_status, queue_status) ) except Exception: self.__conn.rollback() diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 31064a5e7cc..62b5a6b858c 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -284,9 +284,7 @@ def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> Loaded # The model manager emits events as it loads the model. It needs the context data to build # the event payloads. - return self._services.model_manager.load_model_by_key( - key=key, submodel_type=submodel_type, context_data=self._data - ) + return self._services.model_manager.load_model_by_key(key=key, submodel_type=submodel_type) def load_by_attrs( self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None @@ -300,11 +298,7 @@ def load_by_attrs( :param submodel: For main (pipeline models), the submodel to fetch """ return self._services.model_manager.load_model_by_attr( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, - context_data=self._data, + model_name=model_name, base_model=base_model, model_type=model_type, submodel=submodel ) def get_config(self, key: str) -> AnyModelConfig: diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 8cb59f5b3aa..10231bd4b7a 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -39,6 +39,9 @@ def stable_diffusion_step_callback( if is_canceled(): raise CanceledException + # TODO(psyche): Had to put this import here to avoid circular dependencies... fix me! + from invokeai.app.services.events.events_common import InvocationDenoiseProgressEvent + # Some schedulers report not only the noisy latents at the current timestep, # but also their estimate so far of what the de-noised latents will be. Use # that estimate if it is available. @@ -113,15 +116,12 @@ def stable_diffusion_step_callback( dataURL = image_to_dataURL(image, image_format="JPEG") - events.emit_generator_progress( - queue_id=context_data.queue_item.queue_id, - queue_item_id=context_data.queue_item.item_id, - queue_batch_id=context_data.queue_item.batch_id, - graph_execution_state_id=context_data.queue_item.session_id, - node_id=context_data.invocation.id, - source_node_id=context_data.source_invocation_id, - progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), - step=intermediate_state.step, - order=intermediate_state.order, - total_steps=intermediate_state.total_steps, + events.dispatch( + InvocationDenoiseProgressEvent.build( + queue_item=context_data.queue_item, + invocation=context_data.invocation, + step=intermediate_state.step, + total_steps=intermediate_state.total_steps * intermediate_state.order, + progress_image=ProgressImage(dataURL=dataURL, width=width, height=height), + ) )