From 5b43e2ff447bd9cb37d63288c4940b0016084140 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 19 Feb 2024 17:26:08 +1100 Subject: [PATCH] feat: allow opt-out of schema dump to dict - Adds `payload_schema_dump: bool = True` arg to `dispatch`, `_derive_event_name_and_payload_from_pydantic_model` and `_validate_payload`. When `False`, the event handlers are called "live" pydantic model for payload. The default behaviour is unchanged. - Updates tests for this new functionality. Closes #56 --- fastapi_events/dispatcher.py | 34 +++++++++++++++++++++++----------- tests/test_dispatcher.py | 21 ++++++++++++++++----- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/fastapi_events/dispatcher.py b/fastapi_events/dispatcher.py index 01677b5..2e1a381 100644 --- a/fastapi_events/dispatcher.py +++ b/fastapi_events/dispatcher.py @@ -110,7 +110,8 @@ def _derive_event_name_and_payload_from_pydantic_model( event_name_or_model: Union[EventName, PydanticModel], event_name: EventName, payload: Payload, - payload_schema_cls_dict_args: Dict[str, Any] + payload_schema_cls_dict_args: Dict[str, Any], + payload_schema_dump: bool ): """ Derive event_name and payload from Pydantic model @@ -123,10 +124,13 @@ def _derive_event_name_and_payload_from_pydantic_model( if not payload: payload_schema_cls_dict_args = payload_schema_cls_dict_args or DEFAULT_PAYLOAD_SCHEMA_CLS_DICT_ARGS - if IS_PYDANTIC_V1: - payload = event_name_or_model.dict(**payload_schema_cls_dict_args) + if payload_schema_dump: + if IS_PYDANTIC_V1: + payload = event_name_or_model.dict(**payload_schema_cls_dict_args) + else: + payload = event_name_or_model.model_dump(**payload_schema_cls_dict_args) else: - payload = event_name_or_model.model_dump(**payload_schema_cls_dict_args) + payload = event_name_or_model return event_name, payload @@ -135,7 +139,8 @@ def _validate_payload( event_name: EventName, payload: Payload, payload_schema_registry: BaseEventPayloadSchemaRegistry, - payload_schema_cls_dict_args: Dict[str, Any] + payload_schema_cls_dict_args: Dict[str, Any], + payload_schema_dump: bool = True ): """ Validate payload if a corresponding payload schema is registered @@ -146,10 +151,14 @@ def _validate_payload( payload_schema_cls = payload_schema_registry.get(event_name) if payload_schema_cls: payload_schema_cls_dict_args = payload_schema_cls_dict_args or DEFAULT_PAYLOAD_SCHEMA_CLS_DICT_ARGS - if IS_PYDANTIC_V1: - payload = payload_schema_cls(**(payload or {})).dict(**payload_schema_cls_dict_args) + deserialized_payload = payload_schema_cls(**(payload or {})) + if payload_schema_dump: + if IS_PYDANTIC_V1: + payload = deserialized_payload.dict(**payload_schema_cls_dict_args) + else: + payload = deserialized_payload.model_dump(**payload_schema_cls_dict_args) else: - payload = payload_schema_cls(**(payload or {})).model_dump(**payload_schema_cls_dict_args) + payload = deserialized_payload else: logger.debug("Payload schema for event %s not found. Skipping validation...", event_name) @@ -163,7 +172,8 @@ def dispatch( validate_payload: bool = True, payload_schema_cls_dict_args: Optional[Dict[str, Any]] = None, payload_schema_registry: Optional[BaseEventPayloadSchemaRegistry] = None, - middleware_id: Optional[int] = None + middleware_id: Optional[int] = None, + payload_schema_dump: bool = True ) -> None: """ A wrapper of the main dispatcher function with additional checks. @@ -184,7 +194,8 @@ def dispatch( event_name_or_model=event_name_or_model, event_name=event_name, payload=payload, - payload_schema_cls_dict_args=payload_schema_cls_dict_args + payload_schema_cls_dict_args=payload_schema_cls_dict_args, + payload_schema_dump=payload_schema_dump, ) # Validate event payload with schema registered @@ -194,7 +205,8 @@ def dispatch( event_name=event_name, payload=payload, payload_schema_registry=payload_schema_registry, - payload_schema_cls_dict_args=payload_schema_cls_dict_args + payload_schema_cls_dict_args=payload_schema_cls_dict_args, + payload_schema_dump=payload_schema_dump ) # OTEL diff --git a/tests/test_dispatcher.py b/tests/test_dispatcher.py index 595b287..b987d3c 100644 --- a/tests/test_dispatcher.py +++ b/tests/test_dispatcher.py @@ -68,8 +68,12 @@ async def test_suppression_of_events_in_req_res_cycle( ({"user_id": uuid.uuid4()}, True), ({}, True), (None, True))) +@pytest.mark.parametrize( + "payload_schema_dump", + (True, False) +) async def test_payload_validation_with_pydantic_in_req_res_cycle( - event_payload, should_raise_error, setup_mocks_for_events_in_req_res_cycle + event_payload, should_raise_error, payload_schema_dump, setup_mocks_for_events_in_req_res_cycle, ): """ Test if event payloads are properly validated when a payload schema is registered. @@ -88,7 +92,8 @@ class _SignUpEventSchema(pydantic.BaseModel): dispatch_fn = functools.partial(dispatch, event_name=UserEvents.SIGNED_UP, payload=event_payload, - payload_schema_registry=payload_schema) + payload_schema_registry=payload_schema, + payload_schema_dump=payload_schema_dump) if should_raise_error: with pytest.raises(pydantic.ValidationError): @@ -100,8 +105,12 @@ class _SignUpEventSchema(pydantic.BaseModel): @pytest.mark.asyncio +@pytest.mark.parametrize( + "payload_schema_dump", + (True, False) +) async def test_dispatching_with_pydantic_model( - setup_mocks_for_events_in_req_res_cycle, mocker + payload_schema_dump, setup_mocks_for_events_in_req_res_cycle, mocker ): payload_schema = EventPayloadSchemaRegistry() @@ -114,12 +123,14 @@ class UserSignedUpEventSchema(pydantic.BaseModel): username: str - dispatch(UserSignedUpEventSchema(username="USER_ABC")) + event = UserSignedUpEventSchema(username="USER_ABC") + expected_payload = {"username": "USER_ABC"} if payload_schema_dump else event + dispatch(event, payload_schema_dump=payload_schema_dump) assert mocks["spy_event_store_ctx_var"].get.called spy__dispatch.assert_called_with( event_name="USER_SIGNED_UP", - payload={"username": "USER_ABC"} + payload=expected_payload )