Skip to content

Commit

Permalink
Merge pull request #691 from python-openapi/feature/use-openapi-spec-…
Browse files Browse the repository at this point in the history
…validator-spec-version-finder

Use openapi-spec-validator spec version finder
  • Loading branch information
p1c2u authored Oct 13, 2023
2 parents 0f5ac8e + 860ca0a commit df1f1e1
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 24 deletions.
24 changes: 17 additions & 7 deletions openapi_core/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from typing import Union

from jsonschema_path import SchemaPath
from openapi_spec_validator.versions import consts as versions
from openapi_spec_validator.versions.datatypes import SpecVersion
from openapi_spec_validator.versions.exceptions import OpenAPIVersionNotFound
from openapi_spec_validator.versions.shortcuts import get_spec_version

from openapi_core.exceptions import SpecError
from openapi_core.finders import SpecClasses
from openapi_core.finders import SpecFinder
from openapi_core.finders import SpecVersion
from openapi_core.protocols import Request
from openapi_core.protocols import Response
from openapi_core.protocols import WebhookRequest
from openapi_core.spec import Spec
from openapi_core.types import SpecClasses
from openapi_core.unmarshalling.request import V30RequestUnmarshaller
from openapi_core.unmarshalling.request import V31RequestUnmarshaller
from openapi_core.unmarshalling.request import V31WebhookRequestUnmarshaller
Expand Down Expand Up @@ -63,8 +65,8 @@

AnyRequest = Union[Request, WebhookRequest]

SPECS: Dict[SpecVersion, SpecClasses] = {
SpecVersion("openapi", "3.0"): SpecClasses(
SPEC2CLASSES: Dict[SpecVersion, SpecClasses] = {
versions.OPENAPIV30: SpecClasses(
V30RequestValidator,
V30ResponseValidator,
None,
Expand All @@ -74,7 +76,7 @@
None,
None,
),
SpecVersion("openapi", "3.1"): SpecClasses(
versions.OPENAPIV31: SpecClasses(
V31RequestValidator,
V31ResponseValidator,
V31WebhookRequestValidator,
Expand All @@ -88,7 +90,15 @@


def get_classes(spec: SchemaPath) -> SpecClasses:
return SpecFinder(SPECS).get_classes(spec)
try:
spec_version = get_spec_version(spec.contents())
# backward compatibility
except OpenAPIVersionNotFound:
raise SpecError("Spec schema version not detected")
try:
return SPEC2CLASSES[spec_version]
except KeyError:
raise SpecError("Spec schema version not supported")


def unmarshal_apicall_request(
Expand Down
20 changes: 3 additions & 17 deletions openapi_core/finders.py → openapi_core/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Mapping
from typing import NamedTuple
from typing import Optional
Expand All @@ -21,12 +22,8 @@
from openapi_core.validation.validators import BaseValidator


class SpecVersion(NamedTuple):
name: str
version: str


class SpecClasses(NamedTuple):
@dataclass
class SpecClasses:
request_validator_cls: RequestValidatorType
response_validator_cls: ResponseValidatorType
webhook_request_validator_cls: Optional[WebhookRequestValidatorType]
Expand All @@ -37,14 +34,3 @@ class SpecClasses(NamedTuple):
webhook_response_unmarshaller_cls: Optional[
WebhookResponseUnmarshallerType
]


class SpecFinder:
def __init__(self, specs: Mapping[SpecVersion, SpecClasses]) -> None:
self.specs = specs

def get_classes(self, spec: SchemaPath) -> SpecClasses:
for v, classes in self.specs.items():
if v.name in spec and spec[v.name].startswith(v.version):
return classes
raise SpecError("Spec schema version not detected")
5 changes: 5 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from jsonschema_path import SchemaPath


@pytest.fixture
def spec_v20():
return SchemaPath.from_dict({"swagger": "2.0"})


@pytest.fixture
def spec_v30():
return SchemaPath.from_dict({"openapi": "3.0.0"})
Expand Down
80 changes: 80 additions & 0 deletions tests/unit/test_shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.raises(SpecError):
unmarshal_apicall_request(request, spec=spec_invalid)

def test_spec_not_supported(self, spec_v20):
request = mock.Mock(spec=Request)

with pytest.raises(SpecError):
unmarshal_apicall_request(request, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request

Expand Down Expand Up @@ -124,6 +130,12 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.raises(SpecError):
unmarshal_webhook_request(request, spec=spec_invalid)

def test_spec_not_supported(self, spec_v20):
request = mock.Mock(spec=WebhookRequest)

with pytest.raises(SpecError):
unmarshal_webhook_request(request, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request

Expand Down Expand Up @@ -169,6 +181,12 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.raises(SpecError):
unmarshal_request(request, spec=spec_invalid)

def test_spec_not_supported(self, spec_v20):
request = mock.Mock(spec=Request)

with pytest.raises(SpecError):
unmarshal_request(request, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request

Expand Down Expand Up @@ -257,6 +275,13 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.raises(SpecError):
unmarshal_apicall_response(request, response, spec=spec_invalid)

def test_spec_not_supported(self, spec_v20):
request = mock.Mock(spec=Request)
response = mock.Mock(spec=Response)

with pytest.raises(SpecError):
unmarshal_apicall_response(request, response, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request
response = mock.Mock(spec=Response)
Expand Down Expand Up @@ -297,6 +322,13 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.raises(SpecError):
unmarshal_response(request, response, spec=spec_invalid)

def test_spec_not_supported(self, spec_v20):
request = mock.Mock(spec=Request)
response = mock.Mock(spec=Response)

with pytest.raises(SpecError):
unmarshal_response(request, response, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request
response = mock.Mock(spec=Response)
Expand Down Expand Up @@ -404,6 +436,13 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.raises(SpecError):
unmarshal_webhook_response(request, response, spec=spec_invalid)

def test_spec_not_supported(self, spec_v20):
request = mock.Mock(spec=WebhookRequest)
response = mock.Mock(spec=Response)

with pytest.raises(SpecError):
unmarshal_webhook_response(request, response, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request
response = mock.Mock(spec=Response)
Expand Down Expand Up @@ -463,6 +502,12 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.raises(SpecError):
validate_apicall_request(request, spec=spec_invalid)

def test_spec_not_supported(self, spec_v20):
request = mock.Mock(spec=Request)

with pytest.raises(SpecError):
validate_apicall_request(request, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request

Expand Down Expand Up @@ -502,6 +547,12 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.raises(SpecError):
validate_webhook_request(request, spec=spec_invalid)

def test_spec_not_supported(self, spec_v20):
request = mock.Mock(spec=WebhookRequest)

with pytest.raises(SpecError):
validate_webhook_request(request, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request

Expand Down Expand Up @@ -548,6 +599,13 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.warns(DeprecationWarning):
validate_request(request, spec=spec_invalid)

def test_spec_not_detected(self, spec_v20):
request = mock.Mock(spec=Request)

with pytest.raises(SpecError):
with pytest.warns(DeprecationWarning):
validate_request(request, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request

Expand Down Expand Up @@ -705,6 +763,13 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.raises(SpecError):
validate_apicall_response(request, response, spec=spec_invalid)

def test_spec_not_supported(self, spec_v20):
request = mock.Mock(spec=Request)
response = mock.Mock(spec=Response)

with pytest.raises(SpecError):
validate_apicall_response(request, response, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request
response = mock.Mock(spec=Response)
Expand Down Expand Up @@ -758,6 +823,13 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.raises(SpecError):
validate_webhook_response(request, response, spec=spec_invalid)

def test_spec_not_supported(self, spec_v20):
request = mock.Mock(spec=WebhookRequest)
response = mock.Mock(spec=Response)

with pytest.raises(SpecError):
validate_webhook_response(request, response, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request
response = mock.Mock(spec=Response)
Expand Down Expand Up @@ -819,6 +891,14 @@ def test_spec_not_detected(self, spec_invalid):
with pytest.warns(DeprecationWarning):
validate_response(request, response, spec=spec_invalid)

def test_spec_not_supported(self, spec_v20):
request = mock.Mock(spec=Request)
response = mock.Mock(spec=Response)

with pytest.raises(SpecError):
with pytest.warns(DeprecationWarning):
validate_response(request, response, spec=spec_v20)

def test_request_type_invalid(self, spec_v31):
request = mock.sentinel.request
response = mock.Mock(spec=Response)
Expand Down

0 comments on commit df1f1e1

Please sign in to comment.