Skip to content

Commit

Permalink
feat(events): add dynamic invocation & result validators
Browse files Browse the repository at this point in the history
This is required to get these event fields to deserialize correctly. If omitted, pydantic uses `BaseInvocation`/`BaseInvocationOutput`, which is not correct.

This is similar to the workaround in the `Graph` and `GraphExecutionState` classes where we need to fanagle pydantic with manual validation handling.
  • Loading branch information
psychedelicious committed May 28, 2024
1 parent 91214db commit c2e8b7e
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion invokeai/app/services/events/events_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
Expand Down Expand Up @@ -101,6 +101,14 @@ class InvocationEventBase(QueueItemEventBase):
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")

@field_validator("invocation", mode="plain")
@classmethod
def validate_invocation(cls, v: Any):
"""Validates the invocation using the dynamic type adapter."""

invocation = BaseInvocation.get_typeadapter().validate_python(v)
return invocation


@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
Expand Down Expand Up @@ -176,6 +184,14 @@ class InvocationCompleteEvent(InvocationEventBase):

result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")

@field_validator("result", mode="plain")
@classmethod
def validate_results(cls, v: Any):
"""Validates the invocation result using the dynamic type adapter."""

result = BaseInvocationOutput.get_typeadapter().validate_python(v)
return result

@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
Expand Down

0 comments on commit c2e8b7e

Please sign in to comment.