diff --git a/openapi_core/shortcuts.py b/openapi_core/shortcuts.py index 50c60103..00717ffa 100644 --- a/openapi_core/shortcuts.py +++ b/openapi_core/shortcuts.py @@ -5,15 +5,17 @@ from typing import Union from jsonschema_path import SchemaPath +from openapi_spec_validator.versions import consts as versions +from openapi_spec_validator.versions.datatypes import SpecVersion +from openapi_spec_validator.versions.exceptions import OpenAPIVersionNotFound +from openapi_spec_validator.versions.shortcuts import get_spec_version from openapi_core.exceptions import SpecError -from openapi_core.finders import SpecClasses -from openapi_core.finders import SpecFinder -from openapi_core.finders import SpecVersion from openapi_core.protocols import Request from openapi_core.protocols import Response from openapi_core.protocols import WebhookRequest from openapi_core.spec import Spec +from openapi_core.types import SpecClasses from openapi_core.unmarshalling.request import V30RequestUnmarshaller from openapi_core.unmarshalling.request import V31RequestUnmarshaller from openapi_core.unmarshalling.request import V31WebhookRequestUnmarshaller @@ -63,8 +65,8 @@ AnyRequest = Union[Request, WebhookRequest] -SPECS: Dict[SpecVersion, SpecClasses] = { - SpecVersion("openapi", "3.0"): SpecClasses( +SPEC2CLASSES: Dict[SpecVersion, SpecClasses] = { + versions.OPENAPIV30: SpecClasses( V30RequestValidator, V30ResponseValidator, None, @@ -74,7 +76,7 @@ None, None, ), - SpecVersion("openapi", "3.1"): SpecClasses( + versions.OPENAPIV31: SpecClasses( V31RequestValidator, V31ResponseValidator, V31WebhookRequestValidator, @@ -88,7 +90,15 @@ def get_classes(spec: SchemaPath) -> SpecClasses: - return SpecFinder(SPECS).get_classes(spec) + try: + spec_version = get_spec_version(spec.contents()) + # backward compatibility + except OpenAPIVersionNotFound: + raise SpecError("Spec schema version not detected") + try: + return SPEC2CLASSES[spec_version] + except KeyError: + raise SpecError("Spec schema version not supported") def unmarshal_apicall_request( diff --git a/openapi_core/finders.py b/openapi_core/types.py similarity index 74% rename from openapi_core/finders.py rename to openapi_core/types.py index 3cb87b5c..9d9b1bc8 100644 --- a/openapi_core/finders.py +++ b/openapi_core/types.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from typing import Mapping from typing import NamedTuple from typing import Optional @@ -21,12 +22,8 @@ from openapi_core.validation.validators import BaseValidator -class SpecVersion(NamedTuple): - name: str - version: str - - -class SpecClasses(NamedTuple): +@dataclass +class SpecClasses: request_validator_cls: RequestValidatorType response_validator_cls: ResponseValidatorType webhook_request_validator_cls: Optional[WebhookRequestValidatorType] @@ -37,14 +34,3 @@ class SpecClasses(NamedTuple): webhook_response_unmarshaller_cls: Optional[ WebhookResponseUnmarshallerType ] - - -class SpecFinder: - def __init__(self, specs: Mapping[SpecVersion, SpecClasses]) -> None: - self.specs = specs - - def get_classes(self, spec: SchemaPath) -> SpecClasses: - for v, classes in self.specs.items(): - if v.name in spec and spec[v.name].startswith(v.version): - return classes - raise SpecError("Spec schema version not detected") diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 736eb9ab..63fad9df 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -2,6 +2,11 @@ from jsonschema_path import SchemaPath +@pytest.fixture +def spec_v20(): + return SchemaPath.from_dict({"swagger": "2.0"}) + + @pytest.fixture def spec_v30(): return SchemaPath.from_dict({"openapi": "3.0.0"}) diff --git a/tests/unit/test_shortcuts.py b/tests/unit/test_shortcuts.py index 170c4cbf..1d83c569 100644 --- a/tests/unit/test_shortcuts.py +++ b/tests/unit/test_shortcuts.py @@ -97,6 +97,12 @@ def test_spec_not_detected(self, spec_invalid): with pytest.raises(SpecError): unmarshal_apicall_request(request, spec=spec_invalid) + def test_spec_not_supported(self, spec_v20): + request = mock.Mock(spec=Request) + + with pytest.raises(SpecError): + unmarshal_apicall_request(request, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request @@ -124,6 +130,12 @@ def test_spec_not_detected(self, spec_invalid): with pytest.raises(SpecError): unmarshal_webhook_request(request, spec=spec_invalid) + def test_spec_not_supported(self, spec_v20): + request = mock.Mock(spec=WebhookRequest) + + with pytest.raises(SpecError): + unmarshal_webhook_request(request, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request @@ -169,6 +181,12 @@ def test_spec_not_detected(self, spec_invalid): with pytest.raises(SpecError): unmarshal_request(request, spec=spec_invalid) + def test_spec_not_supported(self, spec_v20): + request = mock.Mock(spec=Request) + + with pytest.raises(SpecError): + unmarshal_request(request, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request @@ -257,6 +275,13 @@ def test_spec_not_detected(self, spec_invalid): with pytest.raises(SpecError): unmarshal_apicall_response(request, response, spec=spec_invalid) + def test_spec_not_supported(self, spec_v20): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + + with pytest.raises(SpecError): + unmarshal_apicall_response(request, response, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request response = mock.Mock(spec=Response) @@ -297,6 +322,13 @@ def test_spec_not_detected(self, spec_invalid): with pytest.raises(SpecError): unmarshal_response(request, response, spec=spec_invalid) + def test_spec_not_supported(self, spec_v20): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + + with pytest.raises(SpecError): + unmarshal_response(request, response, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request response = mock.Mock(spec=Response) @@ -404,6 +436,13 @@ def test_spec_not_detected(self, spec_invalid): with pytest.raises(SpecError): unmarshal_webhook_response(request, response, spec=spec_invalid) + def test_spec_not_supported(self, spec_v20): + request = mock.Mock(spec=WebhookRequest) + response = mock.Mock(spec=Response) + + with pytest.raises(SpecError): + unmarshal_webhook_response(request, response, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request response = mock.Mock(spec=Response) @@ -463,6 +502,12 @@ def test_spec_not_detected(self, spec_invalid): with pytest.raises(SpecError): validate_apicall_request(request, spec=spec_invalid) + def test_spec_not_supported(self, spec_v20): + request = mock.Mock(spec=Request) + + with pytest.raises(SpecError): + validate_apicall_request(request, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request @@ -502,6 +547,12 @@ def test_spec_not_detected(self, spec_invalid): with pytest.raises(SpecError): validate_webhook_request(request, spec=spec_invalid) + def test_spec_not_supported(self, spec_v20): + request = mock.Mock(spec=WebhookRequest) + + with pytest.raises(SpecError): + validate_webhook_request(request, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request @@ -548,6 +599,13 @@ def test_spec_not_detected(self, spec_invalid): with pytest.warns(DeprecationWarning): validate_request(request, spec=spec_invalid) + def test_spec_not_detected(self, spec_v20): + request = mock.Mock(spec=Request) + + with pytest.raises(SpecError): + with pytest.warns(DeprecationWarning): + validate_request(request, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request @@ -705,6 +763,13 @@ def test_spec_not_detected(self, spec_invalid): with pytest.raises(SpecError): validate_apicall_response(request, response, spec=spec_invalid) + def test_spec_not_supported(self, spec_v20): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + + with pytest.raises(SpecError): + validate_apicall_response(request, response, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request response = mock.Mock(spec=Response) @@ -758,6 +823,13 @@ def test_spec_not_detected(self, spec_invalid): with pytest.raises(SpecError): validate_webhook_response(request, response, spec=spec_invalid) + def test_spec_not_supported(self, spec_v20): + request = mock.Mock(spec=WebhookRequest) + response = mock.Mock(spec=Response) + + with pytest.raises(SpecError): + validate_webhook_response(request, response, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request response = mock.Mock(spec=Response) @@ -819,6 +891,14 @@ def test_spec_not_detected(self, spec_invalid): with pytest.warns(DeprecationWarning): validate_response(request, response, spec=spec_invalid) + def test_spec_not_supported(self, spec_v20): + request = mock.Mock(spec=Request) + response = mock.Mock(spec=Response) + + with pytest.raises(SpecError): + with pytest.warns(DeprecationWarning): + validate_response(request, response, spec=spec_v20) + def test_request_type_invalid(self, spec_v31): request = mock.sentinel.request response = mock.Mock(spec=Response)