-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(events): use pydantic schemas for events
Our events handling and implementation is not very friendly: - We have to define functions for every type of event we want to emit and call those specific functions. - Adding or removing data from payloads requires changes wherever the events are dispatched from, and in the events service. - We have no type safety for events and need to rely on string matching and dict access when interacting with events. `fastapi_events` has a neat feature where you can create a pydantic model as an event payload, give it an `__event_name__` attr, and then dispatch the model directly. This allows us to eliminate a layer of indirection and some unpleasant complexity: - We do not need functions for every event type. Define the event in a single model, and dispatch the model directly. The events service only has a single `dispatch` method. - Event handler callbacks get type hints for their event payloads, and can use `isinstance` on them if needed. *see note below* - Event payload construction is now the responsibility of the event itself, not the service. Every event model has a `build` class method, encapsulating this logic. - We can generate OpenAPI schemas for the event payloads and get type safety on the frontend. Previously, the types were manually created (and bugs _have_ occurred when a payload changed on the backend). - When registering event callbacks, we can now register the event model itself instead of its event name. Impossible to make typos. This commit moves the backend over to this improved event handling setup. *Note* Actually, `fastapi_events` has baked in conversion of pydantic events to dicts, so event callbacks get an untyped dict instead of a pydantic model. I've raised a PR to make this behaviour configurable: melvinkcx/fastapi-events#57 I used my `fastapi_events` PR branch while working on this PR, so you'd need to use it to test: https://github.com/psychedelicious/fastapi-events/tree/psyche/feat/payload-schema-dump
- Loading branch information
1 parent
3ccb4e6
commit 1ffae51
Showing
15 changed files
with
869 additions
and
800 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,66 +1,119 @@ | ||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) | ||
|
||
from typing import Any | ||
|
||
from fastapi import FastAPI | ||
from fastapi_events.handlers.local import local_handler | ||
from fastapi_events.typing import Event | ||
from pydantic import BaseModel | ||
from socketio import ASGIApp, AsyncServer | ||
|
||
from ..services.events.events_base import EventServiceBase | ||
from invokeai.app.services.events.events_common import ( | ||
BatchEnqueuedEvent, | ||
BulkImageDownloadCompleteEvent, | ||
BulkImageDownloadErrorEvent, | ||
BulkImageDownloadEvent, | ||
BulkImageDownloadStartedEvent, | ||
FastAPIEvent, | ||
InvocationCompleteEvent, | ||
InvocationDenoiseProgressEvent, | ||
InvocationErrorEvent, | ||
InvocationStartedEvent, | ||
ModelEvent, | ||
ModelInstalLCancelledEvent, | ||
ModelInstallCompletedEvent, | ||
ModelInstallDownloadProgressEvent, | ||
ModelInstallErrorEvent, | ||
ModelInstallStartedEvent, | ||
ModelLoadCompleteEvent, | ||
ModelLoadStartedEvent, | ||
QueueClearedEvent, | ||
QueueEvent, | ||
QueueItemStatusChangedEvent, | ||
SessionCanceledEvent, | ||
SessionCompleteEvent, | ||
SessionStartedEvent, | ||
register_events, | ||
) | ||
|
||
|
||
class QueueSubscriptionEvent(BaseModel): | ||
queue_id: str | ||
|
||
|
||
class BulkDownloadSubscriptionEvent(BaseModel): | ||
bulk_download_id: str | ||
|
||
|
||
class SocketIO: | ||
__sio: AsyncServer | ||
__app: ASGIApp | ||
|
||
__sub_queue: str = "subscribe_queue" | ||
__unsub_queue: str = "unsubscribe_queue" | ||
_sub_queue = "subscribe_queue" | ||
_unsub_queue = "unsubscribe_queue" | ||
|
||
__sub_bulk_download: str = "subscribe_bulk_download" | ||
__unsub_bulk_download: str = "unsubscribe_bulk_download" | ||
_sub_bulk_download = "subscribe_bulk_download" | ||
_unsub_bulk_download = "unsubscribe_bulk_download" | ||
|
||
def __init__(self, app: FastAPI): | ||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") | ||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") | ||
app.mount("/ws", self.__app) | ||
|
||
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) | ||
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) | ||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) | ||
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) | ||
|
||
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download) | ||
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download) | ||
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event) | ||
|
||
async def _handle_queue_event(self, event: Event): | ||
await self.__sio.emit( | ||
event=event[1]["event"], | ||
data=event[1]["data"], | ||
room=event[1]["data"]["queue_id"], | ||
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") | ||
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io") | ||
app.mount("/ws", self._app) | ||
|
||
self._sio.on(self._sub_queue, handler=self._handle_sub_queue) | ||
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue) | ||
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download) | ||
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download) | ||
|
||
register_events( | ||
[ | ||
InvocationStartedEvent, | ||
InvocationDenoiseProgressEvent, | ||
InvocationCompleteEvent, | ||
InvocationErrorEvent, | ||
SessionStartedEvent, | ||
SessionCompleteEvent, | ||
SessionCanceledEvent, | ||
QueueItemStatusChangedEvent, | ||
BatchEnqueuedEvent, | ||
QueueClearedEvent, | ||
], | ||
self._handle_queue_event, | ||
) | ||
|
||
register_events( | ||
[ | ||
ModelLoadStartedEvent, | ||
ModelLoadCompleteEvent, | ||
ModelInstallDownloadProgressEvent, | ||
ModelInstallStartedEvent, | ||
ModelInstallCompletedEvent, | ||
ModelInstalLCancelledEvent, | ||
ModelInstallErrorEvent, | ||
], | ||
self._handle_model_event, | ||
) | ||
|
||
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None: | ||
if "queue_id" in data: | ||
await self.__sio.enter_room(sid, data["queue_id"]) | ||
register_events( | ||
[BulkImageDownloadStartedEvent, BulkImageDownloadCompleteEvent, BulkImageDownloadErrorEvent], | ||
self._handle_bulk_image_download_event, | ||
) | ||
|
||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None: | ||
if "queue_id" in data: | ||
await self.__sio.leave_room(sid, data["queue_id"]) | ||
async def _handle_sub_queue(self, sid: str, data: Any) -> None: | ||
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id) | ||
|
||
async def _handle_model_event(self, event: Event) -> None: | ||
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) | ||
async def _handle_unsub_queue(self, sid: str, data: Any) -> None: | ||
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id) | ||
|
||
async def _handle_bulk_download_event(self, event: Event): | ||
await self.__sio.emit( | ||
event=event[1]["event"], | ||
data=event[1]["data"], | ||
room=event[1]["data"]["bulk_download_id"], | ||
) | ||
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: | ||
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) | ||
|
||
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: | ||
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) | ||
|
||
async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]): | ||
event_name, payload = event | ||
await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id) | ||
|
||
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs): | ||
if "bulk_download_id" in data: | ||
await self.__sio.enter_room(sid, data["bulk_download_id"]) | ||
async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None: | ||
event_name, payload = event | ||
await self._sio.emit(event=event_name, data=payload.model_dump()) | ||
|
||
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs): | ||
if "bulk_download_id" in data: | ||
await self.__sio.leave_room(sid, data["bulk_download_id"]) | ||
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkImageDownloadEvent]) -> None: | ||
event_name, payload = event | ||
await self._sio.emit(event=event_name, data=payload.model_dump()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.