Skip to content

Commit

Permalink
Merge pull request #677 from python-openapi/refactor/parameter-and-he…
Browse files Browse the repository at this point in the history
…ader-get-value-refactor

Parameter and header get value refactor
  • Loading branch information
p1c2u authored Sep 23, 2023
2 parents 0da2a38 + 890ae99 commit 7a17349
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 93 deletions.
32 changes: 0 additions & 32 deletions openapi_core/schema/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,38 +45,6 @@ def get_explode(param_or_header: Spec) -> bool:
return style == "form"


def get_value(
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
"""Returns parameter/header value from specific location"""
name = name or param_or_header["name"]
style = get_style(param_or_header)

if name not in location:
# Only check if the name is not in the location if the style of
# the param is deepObject,this is because deepObjects will never be found
# as their key also includes the properties of the object already.
if style != "deepObject":
raise KeyError
keys_str = " ".join(location.keys())
if not re.search(rf"{name}\[\w+\]", keys_str):
raise KeyError

aslist = get_aslist(param_or_header)
explode = get_explode(param_or_header)
if aslist and explode:
if style == "deepObject":
return get_deep_object_value(location, name)
if isinstance(location, SuportsGetAll):
return location.getall(name)
if isinstance(location, SuportsGetList):
return location.getlist(name)

return location[name]


def get_deep_object_value(
location: Mapping[str, Any],
name: Optional[str] = None,
Expand Down
4 changes: 4 additions & 0 deletions openapi_core/templating/media_types/finders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ class MediaTypeFinder:
def __init__(self, content: Spec):
self.content = content

def get_first(self) -> MediaType:
mimetype, media_type = next(self.content.items())
return MediaType(media_type, mimetype)

def find(self, mimetype: str) -> MediaType:
if mimetype in self.content:
return MediaType(self.content / mimetype, mimetype)
Expand Down
6 changes: 3 additions & 3 deletions openapi_core/unmarshalling/response/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _unmarshal(
operation: Spec,
) -> ResponseUnmarshalResult:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
response.status_code, operation
)
# don't process if operation errors
Expand Down Expand Up @@ -96,7 +96,7 @@ def _unmarshal_data(
operation: Spec,
) -> ResponseUnmarshalResult:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
response.status_code, operation
)
# don't process if operation errors
Expand Down Expand Up @@ -124,7 +124,7 @@ def _unmarshal_headers(
operation: Spec,
) -> ResponseUnmarshalResult:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
response.status_code, operation
)
# don't process if operation errors
Expand Down
17 changes: 8 additions & 9 deletions openapi_core/unmarshalling/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,23 @@ def _unmarshal_schema(self, schema: Spec, value: Any) -> Any:
)
return unmarshaller.unmarshal(value)

def _get_param_or_header_value(
def _convert_schema_style_value(
self,
raw: Any,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
casted, schema = self._get_param_or_header_value_and_schema(
param_or_header, location, name
casted, schema = self._convert_schema_style_value_and_schema(
raw, param_or_header
)
if schema is None:
return casted
return self._unmarshal_schema(schema, casted)

def _get_content_value(
self, raw: Any, mimetype: str, content: Spec
def _convert_content_schema_value(
self, raw: Any, content: Spec, mimetype: Optional[str] = None
) -> Any:
casted, schema = self._get_content_value_and_schema(
raw, mimetype, content
casted, schema = self._convert_content_schema_value_and_schema(
raw, content, mimetype
)
if schema is None:
return casted
Expand Down
5 changes: 3 additions & 2 deletions openapi_core/validation/request/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,9 @@ def _get_parameter(

param_location = param["in"]
location = parameters[param_location]

try:
return self._get_param_or_header_value(param, location)
return self._get_param_or_header(param, location, name=name)
except KeyError:
required = param.getkey("required", False)
if required:
Expand Down Expand Up @@ -248,7 +249,7 @@ def _get_body(
content = request_body / "content"

raw_body = self._get_body_value(body, request_body)
return self._get_content_value(raw_body, mimetype, content)
return self._convert_content_schema_value(raw_body, content, mimetype)

def _get_body_value(self, body: Optional[str], request_body: Spec) -> Any:
if not body:
Expand Down
12 changes: 6 additions & 6 deletions openapi_core/validation/response/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _iter_errors(
operation: Spec,
) -> Iterator[Exception]:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
status_code, operation
)
# don't process if operation errors
Expand All @@ -64,7 +64,7 @@ def _iter_data_errors(
self, status_code: int, data: str, mimetype: str, operation: Spec
) -> Iterator[Exception]:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
status_code, operation
)
# don't process if operation errors
Expand All @@ -81,7 +81,7 @@ def _iter_headers_errors(
self, status_code: int, headers: Mapping[str, Any], operation: Spec
) -> Iterator[Exception]:
try:
operation_response = self._get_operation_response(
operation_response = self._find_operation_response(
status_code, operation
)
# don't process if operation errors
Expand All @@ -94,7 +94,7 @@ def _iter_headers_errors(
except HeadersError as exc:
yield from exc.context

def _get_operation_response(
def _find_operation_response(
self,
status_code: int,
operation: Spec,
Expand All @@ -114,7 +114,7 @@ def _get_data(
content = operation_response / "content"

raw_data = self._get_data_value(data)
return self._get_content_value(raw_data, mimetype, content)
return self._convert_content_schema_value(raw_data, content, mimetype)

def _get_data_value(self, data: str) -> Any:
if not data:
Expand Down Expand Up @@ -163,7 +163,7 @@ def _get_header(
)

try:
return self._get_param_or_header_value(header, headers, name=name)
return self._get_param_or_header(header, headers, name=name)
except KeyError:
required = header.getkey("required", False)
if required:
Expand Down
155 changes: 114 additions & 41 deletions openapi_core/validation/validators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""OpenAPI core validation validators module"""
import re
from functools import cached_property
from typing import Any
from typing import Mapping
Expand All @@ -23,7 +24,12 @@
)
from openapi_core.protocols import Request
from openapi_core.protocols import WebhookRequest
from openapi_core.schema.parameters import get_value
from openapi_core.schema.parameters import get_aslist
from openapi_core.schema.parameters import get_deep_object_value
from openapi_core.schema.parameters import get_explode
from openapi_core.schema.parameters import get_style
from openapi_core.schema.protocols import SuportsGetAll
from openapi_core.schema.protocols import SuportsGetList
from openapi_core.spec import Spec
from openapi_core.templating.media_types.datatypes import MediaType
from openapi_core.templating.paths.datatypes import PathOperationServer
Expand Down Expand Up @@ -70,10 +76,14 @@ def __init__(
self.extra_format_validators = extra_format_validators
self.extra_media_type_deserializers = extra_media_type_deserializers

def _get_media_type(self, content: Spec, mimetype: str) -> MediaType:
def _find_media_type(
self, content: Spec, mimetype: Optional[str] = None
) -> MediaType:
from openapi_core.templating.media_types.finders import MediaTypeFinder

finder = MediaTypeFinder(content)
if mimetype is None:
return finder.get_first()
return finder.find(mimetype)

def _deserialise_media_type(self, mimetype: str, value: Any) -> Any:
Expand All @@ -99,69 +109,93 @@ def _validate_schema(self, schema: Spec, value: Any) -> None:
)
validator.validate(value)

def _get_param_or_header_value(
def _get_param_or_header(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
casted, schema = self._get_param_or_header_value_and_schema(
param_or_header, location, name
# Simple scenario
if "content" not in param_or_header:
return self._get_simple_param_or_header(
param_or_header, location, name=name
)

# Complex scenario
return self._get_complex_param_or_header(
param_or_header, location, name=name
)

def _get_simple_param_or_header(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
try:
raw = self._get_style_value(param_or_header, location, name=name)
except KeyError:
# in simple scenrios schema always exist
schema = param_or_header / "schema"
if "default" not in schema:
raise
raw = schema["default"]
return self._convert_schema_style_value(raw, param_or_header)

def _get_complex_param_or_header(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
content = param_or_header / "content"
# no point to catch KetError
# in complex scenrios schema doesn't exist
raw = self._get_media_type_value(param_or_header, location, name=name)
return self._convert_content_schema_value(raw, content)

def _convert_schema_style_value(
self,
raw: Any,
param_or_header: Spec,
) -> Any:
casted, schema = self._convert_schema_style_value_and_schema(
raw, param_or_header
)
if schema is None:
return casted
self._validate_schema(schema, casted)
return casted

def _get_content_value(
self, raw: Any, mimetype: str, content: Spec
def _convert_content_schema_value(
self, raw: Any, content: Spec, mimetype: Optional[str] = None
) -> Any:
casted, schema = self._get_content_value_and_schema(
raw, mimetype, content
casted, schema = self._convert_content_schema_value_and_schema(
raw, content, mimetype
)
if schema is None:
return casted
self._validate_schema(schema, casted)
return casted

def _get_param_or_header_value_and_schema(
def _convert_schema_style_value_and_schema(
self,
raw: Any,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Tuple[Any, Spec]:
try:
raw_value = get_value(param_or_header, location, name=name)
except KeyError:
if "schema" not in param_or_header:
raise
schema = param_or_header / "schema"
if "default" not in schema:
raise
casted = schema["default"]
else:
# Simple scenario
if "content" not in param_or_header:
deserialised = self._deserialise_style(
param_or_header, raw_value
)
schema = param_or_header / "schema"
# Complex scenario
else:
content = param_or_header / "content"
mimetype, media_type = next(content.items())
deserialised = self._deserialise_media_type(
mimetype, raw_value
)
schema = media_type / "schema"
casted = self._cast(schema, deserialised)
deserialised = self._deserialise_style(param_or_header, raw)
schema = param_or_header / "schema"
casted = self._cast(schema, deserialised)
return casted, schema

def _get_content_value_and_schema(
self, raw: Any, mimetype: str, content: Spec
def _convert_content_schema_value_and_schema(
self,
raw: Any,
content: Spec,
mimetype: Optional[str] = None,
) -> Tuple[Any, Optional[Spec]]:
media_type, mimetype = self._get_media_type(content, mimetype)
deserialised = self._deserialise_media_type(mimetype, raw)
media_type, mime_type = self._find_media_type(content, mimetype)
deserialised = self._deserialise_media_type(mime_type, raw)
casted = self._cast(media_type, deserialised)

if "schema" not in media_type:
Expand All @@ -170,6 +204,45 @@ def _get_content_value_and_schema(
schema = media_type / "schema"
return casted, schema

def _get_style_value(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
name = name or param_or_header["name"]
style = get_style(param_or_header)
if name not in location:
# Only check if the name is not in the location if the style of
# the param is deepObject,this is because deepObjects will never be found
# as their key also includes the properties of the object already.
if style != "deepObject":
raise KeyError
keys_str = " ".join(location.keys())
if not re.search(rf"{name}\[\w+\]", keys_str):
raise KeyError

aslist = get_aslist(param_or_header)
explode = get_explode(param_or_header)
if aslist and explode:
if style == "deepObject":
return get_deep_object_value(location, name)
if isinstance(location, SuportsGetAll):
return location.getall(name)
if isinstance(location, SuportsGetList):
return location.getlist(name)

return location[name]

def _get_media_type_value(
self,
param_or_header: Spec,
location: Mapping[str, Any],
name: Optional[str] = None,
) -> Any:
name = name or param_or_header["name"]
return location[name]


class BaseAPICallValidator(BaseValidator):
@cached_property
Expand Down
Loading

0 comments on commit 7a17349

Please sign in to comment.