From a863e8fb7bb314154bccb0829550f780faf84913 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Fri, 15 Sep 2023 15:06:47 +0000 Subject: [PATCH] Skip response validation option --- docs/integrations.rst | 45 +++++ openapi_core/contrib/django/middlewares.py | 14 +- openapi_core/contrib/falcon/middlewares.py | 27 +-- openapi_core/contrib/flask/__init__.py | 2 + openapi_core/contrib/flask/decorators.py | 25 +-- .../data/v3.0/djangoproject/tags/__init__.py | 0 .../data/v3.0/djangoproject/tags/views.py | 13 ++ .../django/data/v3.0/djangoproject/urls.py | 13 +- .../contrib/django/test_django_project.py | 23 +++ tests/integration/contrib/flask/conftest.py | 37 ++++ .../flask/data/v3.0/flask_factory.yaml | 57 ++++++ .../contrib/flask/test_flask_decorator.py | 167 ++++++++++++------ .../contrib/flask/test_flask_validator.py | 31 +--- .../contrib/flask/test_flask_views.py | 134 +++++++------- 14 files changed, 419 insertions(+), 169 deletions(-) create mode 100644 tests/integration/contrib/django/data/v3.0/djangoproject/tags/__init__.py create mode 100644 tests/integration/contrib/django/data/v3.0/djangoproject/tags/views.py create mode 100644 tests/integration/contrib/flask/conftest.py diff --git a/docs/integrations.rst b/docs/integrations.rst index 96229b91..82422989 100644 --- a/docs/integrations.rst +++ b/docs/integrations.rst @@ -63,6 +63,22 @@ Django can be integrated by middleware. Add ``DjangoOpenAPIMiddleware`` to your OPENAPI_SPEC = Spec.from_dict(spec_dict) +You can skip response validation process: by setting ``OPENAPI_RESPONSE_CLS`` to ``None`` + +.. code-block:: python + :emphasize-lines: 10 + + # settings.py + from openapi_core import Spec + + MIDDLEWARE = [ + # ... + 'openapi_core.contrib.django.middlewares.DjangoOpenAPIMiddleware', + ] + + OPENAPI_SPEC = Spec.from_dict(spec_dict) + OPENAPI_RESPONSE_CLS = None + After that you have access to unmarshal result object with all validated request data from Django view through request object. .. code-block:: python @@ -146,6 +162,23 @@ Additional customization parameters can be passed to the middleware. middleware=[openapi_middleware], ) +You can skip response validation process: by setting ``response_cls`` to ``None`` + +.. code-block:: python + :emphasize-lines: 5 + + from openapi_core.contrib.falcon.middlewares import FalconOpenAPIMiddleware + + openapi_middleware = FalconOpenAPIMiddleware.from_spec( + spec, + response_cls=None, + ) + + app = falcon.App( + # ... + middleware=[openapi_middleware], + ) + After that you will have access to validation result object with all validated request data from Falcon view through request context. .. code-block:: python @@ -221,6 +254,18 @@ Additional customization parameters can be passed to the decorator. extra_format_validators=extra_format_validators, ) +You can skip response validation process: by setting ``response_cls`` to ``None`` + +.. code-block:: python + :emphasize-lines: 5 + + from openapi_core.contrib.flask.decorators import FlaskOpenAPIViewDecorator + + openapi = FlaskOpenAPIViewDecorator.from_spec( + spec, + response_cls=None, + ) + If you want to decorate class based view you can use the decorators attribute: .. code-block:: python diff --git a/openapi_core/contrib/django/middlewares.py b/openapi_core/contrib/django/middlewares.py index 5950cff6..6998b9be 100644 --- a/openapi_core/contrib/django/middlewares.py +++ b/openapi_core/contrib/django/middlewares.py @@ -18,8 +18,8 @@ class DjangoOpenAPIMiddleware: - request_class = DjangoOpenAPIRequest - response_class = DjangoOpenAPIResponse + request_cls = DjangoOpenAPIRequest + response_cls = DjangoOpenAPIResponse errors_handler = DjangoOpenAPIErrorsHandler() def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): @@ -28,6 +28,9 @@ def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): if not hasattr(settings, "OPENAPI_SPEC"): raise ImproperlyConfigured("OPENAPI_SPEC not defined in settings") + if hasattr(settings, "OPENAPI_RESPONSE_CLS"): + self.response_cls = settings.OPENAPI_RESPONSE_CLS + self.processor = UnmarshallingProcessor(settings.OPENAPI_SPEC) def __call__(self, request: HttpRequest) -> HttpResponse: @@ -39,6 +42,8 @@ def __call__(self, request: HttpRequest) -> HttpResponse: request.openapi = req_result response = self.get_response(request) + if self.response_cls is None: + return response openapi_response = self._get_openapi_response(response) resp_result = self.processor.process_response( openapi_request, openapi_response @@ -64,9 +69,10 @@ def _handle_response_errors( def _get_openapi_request( self, request: HttpRequest ) -> DjangoOpenAPIRequest: - return self.request_class(request) + return self.request_cls(request) def _get_openapi_response( self, response: HttpResponse ) -> DjangoOpenAPIResponse: - return self.response_class(response) + assert self.response_cls is not None + return self.response_cls(response) diff --git a/openapi_core/contrib/falcon/middlewares.py b/openapi_core/contrib/falcon/middlewares.py index 752dd85f..f30c7f59 100644 --- a/openapi_core/contrib/falcon/middlewares.py +++ b/openapi_core/contrib/falcon/middlewares.py @@ -20,8 +20,8 @@ class FalconOpenAPIMiddleware(UnmarshallingProcessor): - request_class = FalconOpenAPIRequest - response_class = FalconOpenAPIResponse + request_cls = FalconOpenAPIRequest + response_cls = FalconOpenAPIResponse errors_handler = FalconOpenAPIErrorsHandler() def __init__( @@ -29,8 +29,8 @@ def __init__( spec: Spec, request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, - request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, - response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, + request_cls: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, + response_cls: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, errors_handler: Optional[FalconOpenAPIErrorsHandler] = None, **unmarshaller_kwargs: Any, ): @@ -40,8 +40,8 @@ def __init__( response_unmarshaller_cls=response_unmarshaller_cls, **unmarshaller_kwargs, ) - self.request_class = request_class or self.request_class - self.response_class = response_class or self.response_class + self.request_cls = request_cls or self.request_cls + self.response_cls = response_cls or self.response_cls self.errors_handler = errors_handler or self.errors_handler @classmethod @@ -50,8 +50,8 @@ def from_spec( spec: Spec, request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, - request_class: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, - response_class: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, + request_cls: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, + response_cls: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, errors_handler: Optional[FalconOpenAPIErrorsHandler] = None, **unmarshaller_kwargs: Any, ) -> "FalconOpenAPIMiddleware": @@ -59,8 +59,8 @@ def from_spec( spec, request_unmarshaller_cls=request_unmarshaller_cls, response_unmarshaller_cls=response_unmarshaller_cls, - request_class=request_class, - response_class=response_class, + request_cls=request_cls, + response_cls=response_cls, errors_handler=errors_handler, **unmarshaller_kwargs, ) @@ -74,6 +74,8 @@ def process_request(self, req: Request, resp: Response) -> None: # type: ignore def process_response( # type: ignore self, req: Request, resp: Response, resource: Any, req_succeeded: bool ) -> None: + if self.response_cls is None: + return resp openapi_req = self._get_openapi_request(req) openapi_resp = self._get_openapi_response(resp) resp.context.openapi = super().process_response( @@ -101,9 +103,10 @@ def _handle_response_errors( return self.errors_handler.handle(req, resp, response_result.errors) def _get_openapi_request(self, request: Request) -> FalconOpenAPIRequest: - return self.request_class(request) + return self.request_cls(request) def _get_openapi_response( self, response: Response ) -> FalconOpenAPIResponse: - return self.response_class(response) + assert self.response_cls is not None + return self.response_cls(response) diff --git a/openapi_core/contrib/flask/__init__.py b/openapi_core/contrib/flask/__init__.py index b8061df1..c7d0bf2b 100644 --- a/openapi_core/contrib/flask/__init__.py +++ b/openapi_core/contrib/flask/__init__.py @@ -1,7 +1,9 @@ +from openapi_core.contrib.flask.decorators import FlaskOpenAPIViewDecorator from openapi_core.contrib.flask.requests import FlaskOpenAPIRequest from openapi_core.contrib.flask.responses import FlaskOpenAPIResponse __all__ = [ + "FlaskOpenAPIViewDecorator", "FlaskOpenAPIRequest", "FlaskOpenAPIResponse", ] diff --git a/openapi_core/contrib/flask/decorators.py b/openapi_core/contrib/flask/decorators.py index 1d360ae4..7c71ad24 100644 --- a/openapi_core/contrib/flask/decorators.py +++ b/openapi_core/contrib/flask/decorators.py @@ -30,8 +30,10 @@ def __init__( spec: Spec, request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, - request_class: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, - response_class: Type[FlaskOpenAPIResponse] = FlaskOpenAPIResponse, + request_cls: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, + response_cls: Optional[ + Type[FlaskOpenAPIResponse] + ] = FlaskOpenAPIResponse, request_provider: Type[FlaskRequestProvider] = FlaskRequestProvider, openapi_errors_handler: Type[ FlaskOpenAPIErrorsHandler @@ -44,8 +46,8 @@ def __init__( response_unmarshaller_cls=response_unmarshaller_cls, **unmarshaller_kwargs, ) - self.request_class = request_class - self.response_class = response_class + self.request_cls = request_cls + self.response_cls = response_cls self.request_provider = request_provider self.openapi_errors_handler = openapi_errors_handler @@ -60,6 +62,8 @@ def decorated(*args: Any, **kwargs: Any) -> Response: response = self._handle_request_view( request_result, view, *args, **kwargs ) + if self.response_cls is None: + return response openapi_response = self._get_openapi_response(response) response_result = self.process_response( openapi_request, openapi_response @@ -96,12 +100,13 @@ def _get_request(self) -> Request: return request def _get_openapi_request(self, request: Request) -> FlaskOpenAPIRequest: - return self.request_class(request) + return self.request_cls(request) def _get_openapi_response( self, response: Response ) -> FlaskOpenAPIResponse: - return self.response_class(response) + assert self.response_cls is not None + return self.response_cls(response) @classmethod def from_spec( @@ -109,8 +114,8 @@ def from_spec( spec: Spec, request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, - request_class: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, - response_class: Type[FlaskOpenAPIResponse] = FlaskOpenAPIResponse, + request_cls: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, + response_cls: Type[FlaskOpenAPIResponse] = FlaskOpenAPIResponse, request_provider: Type[FlaskRequestProvider] = FlaskRequestProvider, openapi_errors_handler: Type[ FlaskOpenAPIErrorsHandler @@ -121,8 +126,8 @@ def from_spec( spec, request_unmarshaller_cls=request_unmarshaller_cls, response_unmarshaller_cls=response_unmarshaller_cls, - request_class=request_class, - response_class=response_class, + request_cls=request_cls, + response_cls=response_cls, request_provider=request_provider, openapi_errors_handler=openapi_errors_handler, **unmarshaller_kwargs, diff --git a/tests/integration/contrib/django/data/v3.0/djangoproject/tags/__init__.py b/tests/integration/contrib/django/data/v3.0/djangoproject/tags/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/contrib/django/data/v3.0/djangoproject/tags/views.py b/tests/integration/contrib/django/data/v3.0/djangoproject/tags/views.py new file mode 100644 index 00000000..d822b4ff --- /dev/null +++ b/tests/integration/contrib/django/data/v3.0/djangoproject/tags/views.py @@ -0,0 +1,13 @@ +from django.http import HttpResponse +from rest_framework.views import APIView + + +class TagListView(APIView): + def get(self, request): + assert request.openapi + assert not request.openapi.errors + return HttpResponse("success") + + @staticmethod + def get_extra_actions(): + return [] diff --git a/tests/integration/contrib/django/data/v3.0/djangoproject/urls.py b/tests/integration/contrib/django/data/v3.0/djangoproject/urls.py index 9a91da5c..3b4d7329 100644 --- a/tests/integration/contrib/django/data/v3.0/djangoproject/urls.py +++ b/tests/integration/contrib/django/data/v3.0/djangoproject/urls.py @@ -16,7 +16,9 @@ from django.contrib import admin from django.urls import include from django.urls import path -from djangoproject.pets import views +from djangoproject.pets.views import PetDetailView +from djangoproject.pets.views import PetListView +from djangoproject.tags.views import TagListView urlpatterns = [ path("admin/", admin.site.urls), @@ -26,12 +28,17 @@ ), path( "v1/pets", - views.PetListView.as_view(), + PetListView.as_view(), name="pet_list_view", ), path( "v1/pets/", - views.PetDetailView.as_view(), + PetDetailView.as_view(), name="pet_detail_view", ), + path( + "v1/tags", + TagListView.as_view(), + name="tag_list_view", + ), ] diff --git a/tests/integration/contrib/django/test_django_project.py b/tests/integration/contrib/django/test_django_project.py index 38e49870..ed429071 100644 --- a/tests/integration/contrib/django/test_django_project.py +++ b/tests/integration/contrib/django/test_django_project.py @@ -5,6 +5,7 @@ from unittest import mock import pytest +from django.test.utils import override_settings class BaseTestDjangoProject: @@ -372,3 +373,25 @@ def test_post_valid(self, api_client): assert response.status_code == 201 assert not response.content + + +class TestDRFTagListView(BaseTestDRF): + def test_get_response_invalid(self, client): + headers = { + "HTTP_AUTHORIZATION": "Basic testuser", + "HTTP_HOST": "petstore.swagger.io", + } + response = client.get("/v1/tags", **headers) + + assert response.status_code == 415 + + def test_get_skip_response_validation(self, client): + headers = { + "HTTP_AUTHORIZATION": "Basic testuser", + "HTTP_HOST": "petstore.swagger.io", + } + with override_settings(OPENAPI_RESPONSE_CLS=None): + response = client.get("/v1/tags", **headers) + + assert response.status_code == 200 + assert response.content == b"success" diff --git a/tests/integration/contrib/flask/conftest.py b/tests/integration/contrib/flask/conftest.py new file mode 100644 index 00000000..400b1cf7 --- /dev/null +++ b/tests/integration/contrib/flask/conftest.py @@ -0,0 +1,37 @@ +import pytest +from flask import Flask + + +@pytest.fixture(scope="session") +def spec(factory): + specfile = "contrib/flask/data/v3.0/flask_factory.yaml" + return factory.spec_from_file(specfile) + + +@pytest.fixture +def app(app_factory): + return app_factory() + + +@pytest.fixture +def client(client_factory, app): + return client_factory(app) + + +@pytest.fixture(scope="session") +def client_factory(): + def create(app): + return app.test_client() + + return create + + +@pytest.fixture(scope="session") +def app_factory(): + def create(root_path=None): + app = Flask("__main__", root_path=root_path) + app.config["DEBUG"] = True + app.config["TESTING"] = True + return app + + return create diff --git a/tests/integration/contrib/flask/data/v3.0/flask_factory.yaml b/tests/integration/contrib/flask/data/v3.0/flask_factory.yaml index 5d219ed3..17a195db 100644 --- a/tests/integration/contrib/flask/data/v3.0/flask_factory.yaml +++ b/tests/integration/contrib/flask/data/v3.0/flask_factory.yaml @@ -13,6 +13,12 @@ paths: description: the ID of the resource to retrieve schema: type: integer + - name: q + in: query + required: false + description: query key + schema: + type: string get: responses: 404: @@ -58,3 +64,54 @@ paths: type: string message: type: string + post: + requestBody: + description: request data + required: True + content: + application/json: + schema: + type: object + required: + - param1 + properties: + param1: + type: integer + responses: + 200: + description: Return the resource. + content: + application/json: + schema: + type: object + required: + - data + properties: + data: + type: string + headers: + X-Rate-Limit: + description: Rate limit + schema: + type: integer + required: true + default: + description: Return errors. + content: + application/json: + schema: + type: object + required: + - errors + properties: + errors: + type: array + items: + type: object + properties: + title: + type: string + code: + type: string + message: + type: string diff --git a/tests/integration/contrib/flask/test_flask_decorator.py b/tests/integration/contrib/flask/test_flask_decorator.py index 19bea449..9dcf8093 100644 --- a/tests/integration/contrib/flask/test_flask_decorator.py +++ b/tests/integration/contrib/flask/test_flask_decorator.py @@ -7,57 +7,37 @@ from openapi_core.datatypes import Parameters -class TestFlaskOpenAPIDecorator: - view_response_callable = None - - @pytest.fixture - def spec(self, factory): - specfile = "contrib/flask/data/v3.0/flask_factory.yaml" - return factory.spec_from_file(specfile) - - @pytest.fixture - def decorator(self, spec): - return FlaskOpenAPIViewDecorator.from_spec(spec) - - @pytest.fixture - def app(self): - app = Flask("__main__") - app.config["DEBUG"] = True - app.config["TESTING"] = True - return app +@pytest.fixture(scope="session") +def decorator_factory(spec): + def create(**kwargs): + return FlaskOpenAPIViewDecorator.from_spec(spec, **kwargs) - @pytest.fixture - def client(self, app): - with app.test_client() as client: - with app.app_context(): - yield client + return create - @pytest.fixture - def view_response(self): - def view_response(*args, **kwargs): - return self.view_response_callable(*args, **kwargs) - return view_response +@pytest.fixture(scope="session") +def view_factory(decorator_factory): + def create( + app, path, methods=None, view_response_callable=None, decorator=None + ): + decorator = decorator or decorator_factory() - @pytest.fixture(autouse=True) - def details_view(self, app, decorator, view_response): - @app.route("/browse//", methods=["GET", "POST"]) + @app.route(path, methods=methods) @decorator - def browse_details(*args, **kwargs): - return view_response(*args, **kwargs) + def view(*args, **kwargs): + return view_response_callable(*args, **kwargs) - return browse_details + return view + + return create - @pytest.fixture(autouse=True) - def list_view(self, app, decorator, view_response): - @app.route("/browse/") - @decorator - def browse_list(*args, **kwargs): - return view_response(*args, **kwargs) - return browse_list +class TestFlaskOpenAPIDecorator: + @pytest.fixture + def decorator(self, decorator_factory): + return decorator_factory() - def test_invalid_content_type(self, client): + def test_invalid_content_type(self, client, view_factory, app, decorator): def view_response_callable(*args, **kwargs): from flask.globals import request @@ -72,7 +52,13 @@ def view_response_callable(*args, **kwargs): resp.headers["X-Rate-Limit"] = "12" return resp - self.view_response_callable = view_response_callable + view_factory( + app, + "/browse//", + ["GET", "PUT"], + view_response_callable=view_response_callable, + decorator=decorator, + ) result = client.get("/browse/12/") assert result.json == { @@ -91,7 +77,14 @@ def view_response_callable(*args, **kwargs): ] } - def test_server_error(self, client): + def test_server_error(self, client, view_factory, app, decorator): + view_factory( + app, + "/browse//", + ["GET", "PUT"], + view_response_callable=None, + decorator=decorator, + ) result = client.get("/browse/12/", base_url="https://localhost") expected_data = { @@ -112,8 +105,15 @@ def test_server_error(self, client): assert result.status_code == 400 assert result.json == expected_data - def test_operation_error(self, client): - result = client.post("/browse/12/") + def test_operation_error(self, client, view_factory, app, decorator): + view_factory( + app, + "/browse//", + ["GET", "PUT"], + view_response_callable=None, + decorator=decorator, + ) + result = client.put("/browse/12/") expected_data = { "errors": [ @@ -124,7 +124,7 @@ def test_operation_error(self, client): ), "status": 405, "title": ( - "Operation post not found for " + "Operation put not found for " "http://localhost/browse/{id}/" ), } @@ -133,7 +133,13 @@ def test_operation_error(self, client): assert result.status_code == 405 assert result.json == expected_data - def test_path_error(self, client): + def test_path_error(self, client, view_factory, app, decorator): + view_factory( + app, + "/browse/", + view_response_callable=None, + decorator=decorator, + ) result = client.get("/browse/") expected_data = { @@ -153,7 +159,14 @@ def test_path_error(self, client): assert result.status_code == 404 assert result.json == expected_data - def test_endpoint_error(self, client): + def test_endpoint_error(self, client, view_factory, app, decorator): + view_factory( + app, + "/browse//", + ["GET", "PUT"], + view_response_callable=None, + decorator=decorator, + ) result = client.get("/browse/invalidparameter/") expected_data = { @@ -173,7 +186,7 @@ def test_endpoint_error(self, client): } assert result.json == expected_data - def test_response_object_valid(self, client): + def test_response_object_valid(self, client, view_factory, app, decorator): def view_response_callable(*args, **kwargs): from flask.globals import request @@ -188,7 +201,13 @@ def view_response_callable(*args, **kwargs): resp.headers["X-Rate-Limit"] = "12" return resp - self.view_response_callable = view_response_callable + view_factory( + app, + "/browse//", + ["GET", "PUT"], + view_response_callable=view_response_callable, + decorator=decorator, + ) result = client.get("/browse/12/") @@ -197,6 +216,35 @@ def view_response_callable(*args, **kwargs): "data": "data", } + def test_response_skip_validation( + self, client, view_factory, app, decorator_factory + ): + def view_response_callable(*args, **kwargs): + from flask.globals import request + + assert request.openapi + assert not request.openapi.errors + assert request.openapi.parameters == Parameters( + path={ + "id": 12, + } + ) + return make_response("success", 200) + + decorator = decorator_factory(response_cls=None) + view_factory( + app, + "/browse//", + ["GET", "PUT"], + view_response_callable=view_response_callable, + decorator=decorator, + ) + + result = client.get("/browse/12/") + + assert result.status_code == 200 + assert result.text == "success" + @pytest.mark.parametrize( "response,expected_status,expected_headers", [ @@ -217,7 +265,14 @@ def view_response_callable(*args, **kwargs): ], ) def test_tuple_valid( - self, client, response, expected_status, expected_headers + self, + client, + view_factory, + app, + decorator, + response, + expected_status, + expected_headers, ): def view_response_callable(*args, **kwargs): from flask.globals import request @@ -231,7 +286,13 @@ def view_response_callable(*args, **kwargs): ) return response - self.view_response_callable = view_response_callable + view_factory( + app, + "/browse//", + ["GET", "PUT"], + view_response_callable=view_response_callable, + decorator=decorator, + ) result = client.get("/browse/12/") diff --git a/tests/integration/contrib/flask/test_flask_validator.py b/tests/integration/contrib/flask/test_flask_validator.py index 1f4a1a4f..a2fd4332 100644 --- a/tests/integration/contrib/flask/test_flask_validator.py +++ b/tests/integration/contrib/flask/test_flask_validator.py @@ -9,22 +9,9 @@ from openapi_core.contrib.flask import FlaskOpenAPIRequest -class TestWerkzeugOpenAPIValidation: - @pytest.fixture - def spec(self, factory): - specfile = "contrib/requests/data/v3.0/requests_factory.yaml" - return factory.spec_from_file(specfile) - - @pytest.fixture - def app(self): - app = Flask("__main__", root_path="/browse") - app.config["DEBUG"] = True - app.config["TESTING"] = True - return app - - @pytest.fixture - def details_view_func(self, spec): - def datails_browse(id): +class TestFlaskOpenAPIValidation: + def test_request_validator_root_path(self, spec, app_factory): + def details_view_func(id): from flask import request openapi_request = FlaskOpenAPIRequest(request) @@ -42,26 +29,18 @@ def datails_browse(id): else: return Response("Not Found", status=404) - return datails_browse - - @pytest.fixture(autouse=True) - def view(self, app, details_view_func): + app = app_factory(root_path="/browse") app.add_url_rule( "//", view_func=details_view_func, methods=["POST"], ) - - @pytest.fixture - def client(self, app): - return FlaskClient(app) - - def test_request_validator_root_path(self, client): query_string = { "q": "string", } headers = {"content-type": "application/json"} data = {"param1": 1} + client = FlaskClient(app) result = client.post( "/12/", base_url="http://localhost/browse", diff --git a/tests/integration/contrib/flask/test_flask_views.py b/tests/integration/contrib/flask/test_flask_views.py index 5a253ab5..2d786e88 100644 --- a/tests/integration/contrib/flask/test_flask_views.py +++ b/tests/integration/contrib/flask/test_flask_views.py @@ -6,62 +6,50 @@ from openapi_core.contrib.flask.views import FlaskOpenAPIView -class TestFlaskOpenAPIView: - view_response = None - - @pytest.fixture - def spec(self, factory): - specfile = "contrib/flask/data/v3.0/flask_factory.yaml" - return factory.spec_from_file(specfile) - - @pytest.fixture - def app(self): - app = Flask("__main__") - app.config["DEBUG"] = True - app.config["TESTING"] = True - return app - - @pytest.fixture - def client(self, app): - with app.test_client() as client: - with app.app_context(): - yield client - - @pytest.fixture - def details_view_func(self, spec): - outer = self - - class MyDetailsView(FlaskOpenAPIView): - def get(self, id): - return outer.view_response - - def post(self, id): - return outer.view_response - - return MyDetailsView.as_view( - "browse_details", spec, extra_media_type_deserializers={} +@pytest.fixture(scope="session") +def view_factory(): + def create( + spec, + methods=None, + extra_media_type_deserializers=None, + extra_format_validators=None, + ): + if methods is None: + + def get(view, id): + return make_response("success", 200) + + methods = { + "get": get, + } + MyView = type("MyView", (FlaskOpenAPIView,), methods) + extra_media_type_deserializers = extra_media_type_deserializers or {} + extra_format_validators = extra_format_validators or {} + return MyView.as_view( + "myview", + spec, + extra_media_type_deserializers=extra_media_type_deserializers, + extra_format_validators=extra_format_validators, ) - @pytest.fixture - def list_view_func(self, spec): - outer = self + return create - class MyListView(FlaskOpenAPIView): - def get(self): - return outer.view_response - return MyListView.as_view( - "browse_list", spec, extra_format_validators={} - ) +class TestFlaskOpenAPIView: + @pytest.fixture + def client(self, client_factory, app): + client = client_factory(app) + with app.app_context(): + yield client - @pytest.fixture(autouse=True) - def view(self, app, details_view_func, list_view_func): - app.add_url_rule("/browse//", view_func=details_view_func) - app.add_url_rule("/browse/", view_func=list_view_func) + def test_invalid_content_type(self, client, app, spec, view_factory): + def get(view, id): + view_response = make_response("success", 200) + view_response.headers["X-Rate-Limit"] = "12" + return view_response - def test_invalid_content_type(self, client): - self.view_response = make_response("success", 200) - self.view_response.headers["X-Rate-Limit"] = "12" + view_func = view_factory(spec, {"get": get}) + app.add_url_rule("/browse//", view_func=view_func) result = client.get("/browse/12/") @@ -82,7 +70,10 @@ def test_invalid_content_type(self, client): ] } - def test_server_error(self, client): + def test_server_error(self, client, app, spec, view_factory): + view_func = view_factory(spec) + app.add_url_rule("/browse//", view_func=view_func) + result = client.get("/browse/12/", base_url="https://localhost") expected_data = { @@ -103,8 +94,14 @@ def test_server_error(self, client): assert result.status_code == 400 assert result.json == expected_data - def test_operation_error(self, client): - result = client.post("/browse/12/") + def test_operation_error(self, client, app, spec, view_factory): + def put(view, id): + return make_response("success", 200) + + view_func = view_factory(spec, {"put": put}) + app.add_url_rule("/browse//", view_func=view_func) + + result = client.put("/browse/12/") expected_data = { "errors": [ @@ -115,7 +112,7 @@ def test_operation_error(self, client): ), "status": 405, "title": ( - "Operation post not found for " + "Operation put not found for " "http://localhost/browse/{id}/" ), } @@ -124,7 +121,10 @@ def test_operation_error(self, client): assert result.status_code == 405 assert result.json == expected_data - def test_path_error(self, client): + def test_path_error(self, client, app, spec, view_factory): + view_func = view_factory(spec) + app.add_url_rule("/browse/", view_func=view_func) + result = client.get("/browse/") expected_data = { @@ -144,7 +144,10 @@ def test_path_error(self, client): assert result.status_code == 404 assert result.json == expected_data - def test_endpoint_error(self, client): + def test_endpoint_error(self, client, app, spec, view_factory): + view_func = view_factory(spec) + app.add_url_rule("/browse//", view_func=view_func) + result = client.get("/browse/invalidparameter/") expected_data = { @@ -165,8 +168,12 @@ def test_endpoint_error(self, client): assert result.status_code == 400 assert result.json == expected_data - def test_missing_required_header(self, client): - self.view_response = jsonify(data="data") + def test_missing_required_header(self, client, app, spec, view_factory): + def get(view, id): + return jsonify(data="data") + + view_func = view_factory(spec, {"get": get}) + app.add_url_rule("/browse//", view_func=view_func) result = client.get("/browse/12/") @@ -185,9 +192,14 @@ def test_missing_required_header(self, client): assert result.status_code == 400 assert result.json == expected_data - def test_valid(self, client): - self.view_response = jsonify(data="data") - self.view_response.headers["X-Rate-Limit"] = "12" + def test_valid(self, client, app, spec, view_factory): + def get(view, id): + resp = jsonify(data="data") + resp.headers["X-Rate-Limit"] = "12" + return resp + + view_func = view_factory(spec, {"get": get}) + app.add_url_rule("/browse//", view_func=view_func) result = client.get("/browse/12/")