Skip to content

Commit

Permalink
Merge pull request #57 from psychedelicious/psyche/feat/payload-schem…
Browse files Browse the repository at this point in the history
…a-dump

feat: allow opt-out of schema dump to dict
  • Loading branch information
melvinkcx authored Feb 20, 2024
2 parents bb3b55f + 5b43e2f commit 708b3a6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 16 deletions.
34 changes: 23 additions & 11 deletions fastapi_events/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def _derive_event_name_and_payload_from_pydantic_model(
event_name_or_model: Union[EventName, PydanticModel],
event_name: EventName,
payload: Payload,
payload_schema_cls_dict_args: Dict[str, Any]
payload_schema_cls_dict_args: Dict[str, Any],
payload_schema_dump: bool
):
"""
Derive event_name and payload from Pydantic model
Expand All @@ -123,10 +124,13 @@ def _derive_event_name_and_payload_from_pydantic_model(

if not payload:
payload_schema_cls_dict_args = payload_schema_cls_dict_args or DEFAULT_PAYLOAD_SCHEMA_CLS_DICT_ARGS
if IS_PYDANTIC_V1:
payload = event_name_or_model.dict(**payload_schema_cls_dict_args)
if payload_schema_dump:
if IS_PYDANTIC_V1:
payload = event_name_or_model.dict(**payload_schema_cls_dict_args)
else:
payload = event_name_or_model.model_dump(**payload_schema_cls_dict_args)
else:
payload = event_name_or_model.model_dump(**payload_schema_cls_dict_args)
payload = event_name_or_model

return event_name, payload

Expand All @@ -135,7 +139,8 @@ def _validate_payload(
event_name: EventName,
payload: Payload,
payload_schema_registry: BaseEventPayloadSchemaRegistry,
payload_schema_cls_dict_args: Dict[str, Any]
payload_schema_cls_dict_args: Dict[str, Any],
payload_schema_dump: bool = True
):
"""
Validate payload if a corresponding payload schema is registered
Expand All @@ -146,10 +151,14 @@ def _validate_payload(
payload_schema_cls = payload_schema_registry.get(event_name)
if payload_schema_cls:
payload_schema_cls_dict_args = payload_schema_cls_dict_args or DEFAULT_PAYLOAD_SCHEMA_CLS_DICT_ARGS
if IS_PYDANTIC_V1:
payload = payload_schema_cls(**(payload or {})).dict(**payload_schema_cls_dict_args)
deserialized_payload = payload_schema_cls(**(payload or {}))
if payload_schema_dump:
if IS_PYDANTIC_V1:
payload = deserialized_payload.dict(**payload_schema_cls_dict_args)
else:
payload = deserialized_payload.model_dump(**payload_schema_cls_dict_args)
else:
payload = payload_schema_cls(**(payload or {})).model_dump(**payload_schema_cls_dict_args)
payload = deserialized_payload
else:
logger.debug("Payload schema for event %s not found. Skipping validation...", event_name)

Expand All @@ -163,7 +172,8 @@ def dispatch(
validate_payload: bool = True,
payload_schema_cls_dict_args: Optional[Dict[str, Any]] = None,
payload_schema_registry: Optional[BaseEventPayloadSchemaRegistry] = None,
middleware_id: Optional[int] = None
middleware_id: Optional[int] = None,
payload_schema_dump: bool = True
) -> None:
"""
A wrapper of the main dispatcher function with additional checks.
Expand All @@ -184,7 +194,8 @@ def dispatch(
event_name_or_model=event_name_or_model,
event_name=event_name,
payload=payload,
payload_schema_cls_dict_args=payload_schema_cls_dict_args
payload_schema_cls_dict_args=payload_schema_cls_dict_args,
payload_schema_dump=payload_schema_dump,
)

# Validate event payload with schema registered
Expand All @@ -194,7 +205,8 @@ def dispatch(
event_name=event_name,
payload=payload,
payload_schema_registry=payload_schema_registry,
payload_schema_cls_dict_args=payload_schema_cls_dict_args
payload_schema_cls_dict_args=payload_schema_cls_dict_args,
payload_schema_dump=payload_schema_dump
)

# OTEL
Expand Down
21 changes: 16 additions & 5 deletions tests/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ async def test_suppression_of_events_in_req_res_cycle(
({"user_id": uuid.uuid4()}, True),
({}, True),
(None, True)))
@pytest.mark.parametrize(
"payload_schema_dump",
(True, False)
)
async def test_payload_validation_with_pydantic_in_req_res_cycle(
event_payload, should_raise_error, setup_mocks_for_events_in_req_res_cycle
event_payload, should_raise_error, payload_schema_dump, setup_mocks_for_events_in_req_res_cycle,
):
"""
Test if event payloads are properly validated when a payload schema is registered.
Expand All @@ -88,7 +92,8 @@ class _SignUpEventSchema(pydantic.BaseModel):
dispatch_fn = functools.partial(dispatch,
event_name=UserEvents.SIGNED_UP,
payload=event_payload,
payload_schema_registry=payload_schema)
payload_schema_registry=payload_schema,
payload_schema_dump=payload_schema_dump)

if should_raise_error:
with pytest.raises(pydantic.ValidationError):
Expand All @@ -100,8 +105,12 @@ class _SignUpEventSchema(pydantic.BaseModel):


@pytest.mark.asyncio
@pytest.mark.parametrize(
"payload_schema_dump",
(True, False)
)
async def test_dispatching_with_pydantic_model(
setup_mocks_for_events_in_req_res_cycle, mocker
payload_schema_dump, setup_mocks_for_events_in_req_res_cycle, mocker
):
payload_schema = EventPayloadSchemaRegistry()

Expand All @@ -114,12 +123,14 @@ class UserSignedUpEventSchema(pydantic.BaseModel):

username: str

dispatch(UserSignedUpEventSchema(username="USER_ABC"))
event = UserSignedUpEventSchema(username="USER_ABC")
expected_payload = {"username": "USER_ABC"} if payload_schema_dump else event
dispatch(event, payload_schema_dump=payload_schema_dump)

assert mocks["spy_event_store_ctx_var"].get.called
spy__dispatch.assert_called_with(
event_name="USER_SIGNED_UP",
payload={"username": "USER_ABC"}
payload=expected_payload
)


Expand Down

0 comments on commit 708b3a6

Please sign in to comment.