-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(events): use pydantic schemas for events
Our events handling and implementation is not very friendly: - We have to define functions for every type of event we want to emit and call those specific functions. - Adding or removing data from payloads requires changes wherever the events are dispatched from, and in the events service. - We have no type safety for events and need to rely on string matching and dict access when interacting with events. `fastapi_events` has a neat feature where you can create a pydantic model as an event payload, give it an `__event_name__` attr, and then dispatch the model directly. This allows us to eliminate a layer of indirection and some unpleasant complexity: - We do not need functions for every event type. Define the event in a single model, and dispatch the model directly. The events service only has a single `dispatch` method. - Event handler callbacks get type hints for their event payloads, and can use `isinstance` on them if needed. *see note below* - Event payload construction is now the responsibility of the event itself, not the service. Every event model has a `build` class method, encapsulating this logic. - We can generate OpenAPI schemas for the event payloads and get type safety on the frontend. Previously, the types were manually created (and bugs _have_ occurred when a payload changed on the backend). - When registering event callbacks, we can now register the event model itself instead of its event name. Impossible to make typos. This commit moves the backend over to this improved event handling setup. *Note* Actually, `fastapi_events` has baked in conversion of pydantic events to dicts, so event callbacks get an untyped dict instead of a pydantic model. I've raised a PR to make this behaviour configurable: melvinkcx/fastapi-events#57 I used my `fastapi_events` PR branch while working on this PR, so you'd need to use it to test: https://github.com/psychedelicious/fastapi-events/tree/psyche/feat/payload-schema-dump chore(ui): typegen refactor(ui): update frontend to use new events setup TODO: Support session_started and session_canceled events. feat(events): improved types build: pin fastapi-events to unreleased version It's merged, but not released yet. fix(events): fix merge issue chore(ui): lint feat(events): migrate MM install events to pydantic feat(ui): update events for new mm install chore: bump fastapi_events feat(events): revise how events are dispatched Dispatching the fully-formed event payloads directly ended up causing circular import issues. Revised so that each event has its own `emit_...` method on the events service, which is how it was originally. This also makes it clear what events are valid in the system.
- Loading branch information
1 parent
a386544
commit 60a5869
Showing
39 changed files
with
2,270 additions
and
1,480 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,66 +1,119 @@ | ||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) | ||
|
||
from typing import Any | ||
|
||
from fastapi import FastAPI | ||
from fastapi_events.handlers.local import local_handler | ||
from fastapi_events.typing import Event | ||
from pydantic import BaseModel | ||
from socketio import ASGIApp, AsyncServer | ||
|
||
from ..services.events.events_base import EventServiceBase | ||
from invokeai.app.services.events.events_common import ( | ||
BatchEnqueuedEvent, | ||
BulkDownloadCompleteEvent, | ||
BulkDownloadErrorEvent, | ||
BulkDownloadEvent, | ||
BulkDownloadStartedEvent, | ||
FastAPIEvent, | ||
InvocationCompleteEvent, | ||
InvocationDenoiseProgressEvent, | ||
InvocationErrorEvent, | ||
InvocationStartedEvent, | ||
ModelEvent, | ||
ModelInstallCancelledEvent, | ||
ModelInstallCompleteEvent, | ||
ModelInstallDownloadProgressEvent, | ||
ModelInstallErrorEvent, | ||
ModelInstallStartedEvent, | ||
ModelLoadCompleteEvent, | ||
ModelLoadStartedEvent, | ||
QueueClearedEvent, | ||
QueueEvent, | ||
QueueItemStatusChangedEvent, | ||
SessionCanceledEvent, | ||
SessionCompleteEvent, | ||
SessionStartedEvent, | ||
register_events, | ||
) | ||
|
||
|
||
class QueueSubscriptionEvent(BaseModel): | ||
queue_id: str | ||
|
||
|
||
class BulkDownloadSubscriptionEvent(BaseModel): | ||
bulk_download_id: str | ||
|
||
|
||
class SocketIO: | ||
__sio: AsyncServer | ||
__app: ASGIApp | ||
|
||
__sub_queue: str = "subscribe_queue" | ||
__unsub_queue: str = "unsubscribe_queue" | ||
_sub_queue = "subscribe_queue" | ||
_unsub_queue = "unsubscribe_queue" | ||
|
||
__sub_bulk_download: str = "subscribe_bulk_download" | ||
__unsub_bulk_download: str = "unsubscribe_bulk_download" | ||
_sub_bulk_download = "subscribe_bulk_download" | ||
_unsub_bulk_download = "unsubscribe_bulk_download" | ||
|
||
def __init__(self, app: FastAPI): | ||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") | ||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") | ||
app.mount("/ws", self.__app) | ||
|
||
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) | ||
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) | ||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) | ||
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) | ||
|
||
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download) | ||
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download) | ||
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event) | ||
|
||
async def _handle_queue_event(self, event: Event): | ||
await self.__sio.emit( | ||
event=event[1]["event"], | ||
data=event[1]["data"], | ||
room=event[1]["data"]["queue_id"], | ||
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") | ||
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io") | ||
app.mount("/ws", self._app) | ||
|
||
self._sio.on(self._sub_queue, handler=self._handle_sub_queue) | ||
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue) | ||
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download) | ||
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download) | ||
|
||
register_events( | ||
{ | ||
InvocationStartedEvent, | ||
InvocationDenoiseProgressEvent, | ||
InvocationCompleteEvent, | ||
InvocationErrorEvent, | ||
SessionStartedEvent, | ||
SessionCompleteEvent, | ||
SessionCanceledEvent, | ||
QueueItemStatusChangedEvent, | ||
BatchEnqueuedEvent, | ||
QueueClearedEvent, | ||
}, | ||
self._handle_queue_event, | ||
) | ||
|
||
register_events( | ||
{ | ||
ModelLoadStartedEvent, | ||
ModelLoadCompleteEvent, | ||
ModelInstallDownloadProgressEvent, | ||
ModelInstallStartedEvent, | ||
ModelInstallCompleteEvent, | ||
ModelInstallCancelledEvent, | ||
ModelInstallErrorEvent, | ||
}, | ||
self._handle_model_event, | ||
) | ||
|
||
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None: | ||
if "queue_id" in data: | ||
await self.__sio.enter_room(sid, data["queue_id"]) | ||
register_events( | ||
{BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}, | ||
self._handle_bulk_image_download_event, | ||
) | ||
|
||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None: | ||
if "queue_id" in data: | ||
await self.__sio.leave_room(sid, data["queue_id"]) | ||
async def _handle_sub_queue(self, sid: str, data: Any) -> None: | ||
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id) | ||
|
||
async def _handle_model_event(self, event: Event) -> None: | ||
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) | ||
async def _handle_unsub_queue(self, sid: str, data: Any) -> None: | ||
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id) | ||
|
||
async def _handle_bulk_download_event(self, event: Event): | ||
await self.__sio.emit( | ||
event=event[1]["event"], | ||
data=event[1]["data"], | ||
room=event[1]["data"]["bulk_download_id"], | ||
) | ||
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: | ||
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) | ||
|
||
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: | ||
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) | ||
|
||
async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]): | ||
event_name, payload = event | ||
await self._sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id) | ||
|
||
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs): | ||
if "bulk_download_id" in data: | ||
await self.__sio.enter_room(sid, data["bulk_download_id"]) | ||
async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None: | ||
event_name, payload = event | ||
await self._sio.emit(event=event_name, data=payload.model_dump()) | ||
|
||
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs): | ||
if "bulk_download_id" in data: | ||
await self.__sio.leave_room(sid, data["bulk_download_id"]) | ||
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEvent]) -> None: | ||
event_name, payload = event | ||
await self._sio.emit(event=event_name, data=payload.model_dump()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.