Skip to content

Commit

Permalink
refactor(events): use pydantic schemas for events
Browse files Browse the repository at this point in the history
Our events handling and implementation is not very friendly:
- We have to define functions for every type of event we want to emit and call those specific functions.
- Adding or removing data from payloads requires changes wherever the events are dispatched from, and in the events service.
- We have no type safety for events and need to rely on string matching and dict access when interacting with events.

`fastapi_events` has a neat feature where you can create a pydantic model as an event payload, give it an `__event_name__` attr, and then dispatch the model directly.

This allows us to eliminate a layer of indirection and some unpleasant complexity:
- We do not need functions for every event type. Define the event in a single model, and dispatch the model directly. The events service only has a single `dispatch` method.
- Event handler callbacks get type hints for their event payloads, and can use `isinstance` on them if needed. *see note below*
- Event payload construction is now the responsibility of the event itself, not the service. Every event model has a `build` class method, encapsulating this logic.
- We can generate OpenAPI schemas for the event payloads and get type safety on the frontend. Previously, the types were manually created (and bugs _have_ occurred when a payload changed on the backend).
- When registering event callbacks, we can now register the event model itself instead of its event name. Impossible to make typos.

This commit moves the backend over to this improved event handling setup.

*Note*
Actually, `fastapi_events` has baked in conversion of pydantic events to dicts, so event callbacks get an untyped dict instead of a pydantic model. I've raised a PR to make this behaviour configurable: melvinkcx/fastapi-events#57

I used my `fastapi_events` PR branch while working on this PR, so you'd need to use it to test: https://github.com/psychedelicious/fastapi-events/tree/psyche/feat/payload-schema-dump
  • Loading branch information
psychedelicious committed Feb 19, 2024
1 parent 9d79ee8 commit f82632f
Show file tree
Hide file tree
Showing 15 changed files with 774 additions and 733 deletions.
2 changes: 1 addition & 1 deletion invokeai/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..services.boards.boards_default import BoardService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.events.events_fastapievents import FastAPIEventService
from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
Expand All @@ -32,7 +33,6 @@
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService


# TODO: is there a better way to achieve this?
Expand Down
52 changes: 0 additions & 52 deletions invokeai/app/api/events.py

This file was deleted.

88 changes: 69 additions & 19 deletions invokeai/app/api/sockets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,38 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from typing import Any

from fastapi import FastAPI
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from pydantic import BaseModel
from socketio import ASGIApp, AsyncServer

from ..services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEvent,
ModelInstalLCancelledEvent,
ModelInstallCompletedEvent,
ModelInstallDownloadProgressEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEvent,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
register_events,
)


class QueueSubscriptionEvent(BaseModel):
queue_id: str


class SocketIO:
Expand All @@ -19,23 +46,46 @@ def __init__(self, app: FastAPI):

self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)

async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["queue_id"],

register_events(
[
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
SessionStartedEvent,
SessionCompleteEvent,
SessionCanceledEvent,
QueueItemStatusChangedEvent,
BatchEnqueuedEvent,
QueueClearedEvent,
],
self._handle_queue_event,
)

async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"])
register_events(
[
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallStartedEvent,
ModelInstallCompletedEvent,
ModelInstalLCancelledEvent,
ModelInstallErrorEvent,
],
self._handle_model_event,
)

async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self.__sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)

async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
await self.__sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)

async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]):
event_name, payload = event
await self.__sio.emit(event=event_name, data=payload.model_dump(), room=payload.queue_id)

async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None:
event_name, payload = event
await self.__sio.emit(event=event_name, data=payload.model_dump())
31 changes: 14 additions & 17 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# which are imported/used before parse_args() is called will get the default config values instead of the
# values from the command line or config file.
import sys
from typing import cast

from pydantic import BaseModel

from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.version.invokeai_version import __version__
Expand Down Expand Up @@ -32,6 +35,7 @@
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from fastapi_events.registry.payload_schema import registry as fastapi_events_registry
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available

Expand Down Expand Up @@ -173,23 +177,16 @@ def custom_openapi() -> dict[str, Any]:
invoker_schema["output"] = outputs_ref
invoker_schema["class"] = "invocation"

# This code no longer seems to be necessary?
# Leave it here just in case
#
# from invokeai.backend.model_manager import get_model_config_formats
# formats = get_model_config_formats()
# for model_config_name, enum_set in formats.items():

# if model_config_name in openapi_schema["components"]["schemas"]:
# # print(f"Config with name {name} already defined")
# continue

# openapi_schema["components"]["schemas"][model_config_name] = {
# "title": model_config_name,
# "description": "An enumeration.",
# "type": "string",
# "enum": [v.value for v in enum_set],
# }
# Add all pydantic event schemas registered with fastapi-events
for payload in fastapi_events_registry.data.values():
json_schema = cast(BaseModel, payload).model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)
if "$defs" in json_schema:
for schema_key, schema in json_schema["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema
del json_schema["$defs"]
openapi_schema["components"]["schemas"][payload.__name__] = json_schema

app.openapi_schema = openapi_schema
return app.openapi_schema
Expand Down
Loading

0 comments on commit f82632f

Please sign in to comment.