From f9e6a249cda9ed0546b51382627958790272053f Mon Sep 17 00:00:00 2001 From: "T. Franzel" Date: Thu, 11 Nov 2021 11:44:45 +0100 Subject: [PATCH] improve mypy typing #600 --- drf_spectacular/drainage.py | 33 +++++++------ drf_spectacular/generators.py | 2 +- drf_spectacular/openapi.py | 50 +++++++++++--------- drf_spectacular/plumbing.py | 88 ++++++++++++++++++----------------- drf_spectacular/utils.py | 4 +- requirements/docs.txt | 1 + tox.ini | 17 +++++++ 7 files changed, 115 insertions(+), 80 deletions(-) diff --git a/drf_spectacular/drainage.py b/drf_spectacular/drainage.py index d329fe95..39640439 100644 --- a/drf_spectacular/drainage.py +++ b/drf_spectacular/drainage.py @@ -2,14 +2,17 @@ import functools import sys from collections import defaultdict -from typing import DefaultDict +from typing import Any, DefaultDict if sys.version_info >= (3, 8): from typing import Final, Literal, _TypedDictMeta # type: ignore[attr-defined] # noqa: F401 else: - from typing_extensions import ( # type: ignore[attr-defined] # noqa: F401 - Final, Literal, _TypedDictMeta, - ) + from typing_extensions import Final, Literal, _TypedDictMeta # noqa: F401 + +if sys.version_info >= (3, 10): + from typing import TypeGuard # noqa: F401 +else: + from typing_extensions import TypeGuard # noqa: F401 class GeneratorStats: @@ -33,11 +36,11 @@ def silence(self): finally: self.silent = tmp - def reset(self): + def reset(self) -> None: self._warn_cache.clear() self._error_cache.clear() - def emit(self, msg, severity): + def emit(self, msg: str, severity: str) -> None: assert severity in ['warning', 'error'] msg = _get_current_trace() + str(msg) cache = self._warn_cache if severity == 'warning' else self._error_cache @@ -45,7 +48,7 @@ def emit(self, msg, severity): print(f'{severity.capitalize()} #{len(cache)}: {msg}', file=sys.stderr) cache[msg] += 1 - def emit_summary(self): + def emit_summary(self) -> None: if not self.silent and (self._warn_cache or self._error_cache): print( f'\nSchema generation summary:\n' @@ -58,15 +61,15 @@ def emit_summary(self): GENERATOR_STATS = GeneratorStats() -def warn(msg): +def warn(msg: str) -> None: GENERATOR_STATS.emit(msg, 'warning') -def error(msg): +def error(msg: str) -> None: GENERATOR_STATS.emit(msg, 'error') -def reset_generator_stats(): +def reset_generator_stats() -> None: GENERATOR_STATS.reset() @@ -74,7 +77,7 @@ def reset_generator_stats(): @contextlib.contextmanager -def add_trace_message(trace_message): +def add_trace_message(trace_message: str): """ Adds a message to be used as a prefix when emitting warnings and errors. """ @@ -83,11 +86,11 @@ def add_trace_message(trace_message): _TRACES.pop() -def _get_current_trace(): +def _get_current_trace() -> str: return ''.join(f"{trace}: " for trace in _TRACES if trace) -def has_override(obj, prop): +def has_override(obj, prop: str) -> bool: if isinstance(obj, functools.partial): obj = obj.func if not hasattr(obj, '_spectacular_annotation'): @@ -97,7 +100,7 @@ def has_override(obj, prop): return True -def get_override(obj, prop, default=None): +def get_override(obj, prop: str, default: Any = None) -> Any: if isinstance(obj, functools.partial): obj = obj.func if not has_override(obj, prop): @@ -105,7 +108,7 @@ def get_override(obj, prop, default=None): return obj._spectacular_annotation[prop] -def set_override(obj, prop, value): +def set_override(obj, prop: str, value: Any): if not hasattr(obj, '_spectacular_annotation'): obj._spectacular_annotation = {} elif '_spectacular_annotation' not in obj.__dict__: diff --git a/drf_spectacular/generators.py b/drf_spectacular/generators.py index 8850a933..051cd252 100644 --- a/drf_spectacular/generators.py +++ b/drf_spectacular/generators.py @@ -3,7 +3,7 @@ from django.urls import URLPattern, URLResolver from rest_framework import views, viewsets -from rest_framework.schemas.generators import BaseSchemaGenerator # type: ignore +from rest_framework.schemas.generators import BaseSchemaGenerator from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator from rest_framework.settings import api_settings diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index afd9e10d..22a615f7 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -1,7 +1,7 @@ import copy import re -import typing from collections import defaultdict +from typing import Any, Dict, List, Optional, Union import uritemplate from django.core import exceptions as django_exceptions @@ -13,7 +13,7 @@ from rest_framework.generics import CreateAPIView, GenericAPIView, ListCreateAPIView from rest_framework.mixins import ListModelMixin from rest_framework.schemas.inspectors import ViewInspector -from rest_framework.schemas.utils import get_pk_description # type: ignore +from rest_framework.schemas.utils import get_pk_description from rest_framework.settings import api_settings from rest_framework.utils.model_meta import get_field_info from rest_framework.views import APIView @@ -36,7 +36,9 @@ ) from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiParameter, OpenApiResponse +from drf_spectacular.utils import ( + Direction, OpenApiParameter, OpenApiResponse, _SchemaType, _SerializerType, +) class AutoSchema(ViewInspector): @@ -48,14 +50,14 @@ class AutoSchema(ViewInspector): 'delete': 'destroy', } - def get_operation(self, path, path_regex, path_prefix, method, registry: ComponentRegistry): + def get_operation(self, path, path_regex, path_prefix, method, registry: ComponentRegistry) -> _SchemaType: self.registry = registry self.path = path self.path_regex = path_regex self.path_prefix = path_prefix self.method = method.upper() - operation = {'operationId': self.get_operation_id()} + operation: _SchemaType = {'operationId': self.get_operation_id()} description = self.get_description() if description: @@ -135,7 +137,7 @@ def _is_create_operation(self): return True return False - def get_override_parameters(self): + def get_override_parameters(self) -> List[Union[OpenApiParameter, _SerializerType]]: """ override this for custom behaviour """ return [] @@ -189,13 +191,13 @@ def _process_override_parameters(self): warn(f'could not resolve parameter annotation {parameter}. Skipping.') return result - def _get_format_parameters(self): + def _get_format_parameters(self) -> List[dict]: parameters = [] formats = self.map_renderers('format') if api_settings.URL_FORMAT_OVERRIDE and len(formats) > 1: parameters.append(build_parameter_type( name=api_settings.URL_FORMAT_OVERRIDE, - schema=build_basic_type(OpenApiTypes.STR), + schema=build_basic_type(OpenApiTypes.STR), # type: ignore location=OpenApiParameter.QUERY, enum=formats )) @@ -243,14 +245,14 @@ def dict_helper(parameters): else: return list(parameters.values()) - def get_description(self): + def get_description(self) -> str: # type: ignore """ override this for custom behaviour """ action_or_method = getattr(self.view, getattr(self.view, 'action', self.method.lower()), None) view_doc = get_doc(self.view.__class__) action_doc = get_doc(action_or_method) return action_doc or view_doc - def get_summary(self): + def get_summary(self) -> Optional[str]: """ override this for custom behaviour """ return None @@ -305,24 +307,24 @@ def get_auth(self): auths.append({}) return auths - def get_request_serializer(self) -> typing.Any: + def get_request_serializer(self) -> Any: """ override this for custom behaviour """ return self._get_serializer() - def get_response_serializers(self) -> typing.Any: + def get_response_serializers(self) -> Any: """ override this for custom behaviour """ return self._get_serializer() - def get_tags(self) -> typing.List[str]: + def get_tags(self) -> List[str]: """ override this for custom behaviour """ tokenized_path = self._tokenize_path() # use first non-parameter path part as tag return tokenized_path[:1] - def get_extensions(self) -> typing.Dict[str, typing.Any]: + def get_extensions(self) -> Dict[str, Any]: return {} - def get_operation_id(self): + def get_operation_id(self) -> str: """ override this for custom behaviour """ tokenized_path = self._tokenize_path() # replace dashes as they can be problematic later in code generation @@ -341,11 +343,11 @@ def get_operation_id(self): return '_'.join(tokenized_path + [action]) - def is_deprecated(self): + def is_deprecated(self) -> bool: """ override this for custom behaviour """ return False - def _tokenize_path(self): + def _tokenize_path(self) -> List[str]: # remove path prefix path = re.sub( pattern=self.path_prefix, @@ -1168,7 +1170,7 @@ def _get_response_bodies(self): schema['description'] = _('Unspecified response body') return {'200': self._get_response_for_code(schema, '200')} - def _unwrap_list_serializer(self, serializer, direction) -> typing.Optional[dict]: + def _unwrap_list_serializer(self, serializer, direction) -> Optional[dict]: if is_field(serializer): return self._map_serializer_field(serializer, direction) elif is_basic_serializer(serializer): @@ -1279,7 +1281,11 @@ def _get_response_headers_for_code(self, status_code) -> dict: elif is_serializer(parameter.type): schema = self.resolve_serializer(parameter.type, 'response').ref else: - schema = parameter.type + schema = parameter.type # type: ignore + + if not schema: + warn(f'response parameter {parameter.name} requires non-empty schema') + continue if parameter.location not in [OpenApiParameter.HEADER, OpenApiParameter.COOKIE]: warn(f'incompatible location type ignored for response parameter {parameter.name}') @@ -1305,7 +1311,7 @@ def _get_response_headers_for_code(self, status_code) -> dict: return result - def _get_serializer_name(self, serializer, direction): + def _get_serializer_name(self, serializer, direction: Direction) -> str: serializer_extension = OpenApiSerializerExtension.get_match(serializer) if serializer_extension and serializer_extension.get_name(): # library override mechanisms @@ -1322,6 +1328,8 @@ def _get_serializer_name(self, serializer, direction): else: name = serializer.__class__.__name__ + assert name + if name.endswith('Serializer'): name = name[:-10] @@ -1333,7 +1341,7 @@ def _get_serializer_name(self, serializer, direction): return name - def resolve_serializer(self, serializer, direction) -> ResolvedComponent: + def resolve_serializer(self, serializer: _SerializerType, direction: Direction) -> ResolvedComponent: assert_basic_serializer(serializer) serializer = force_instance(serializer) diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index eab0e187..3a6143c9 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -12,7 +12,7 @@ from collections import OrderedDict, defaultdict from decimal import Decimal from enum import Enum -from typing import Any, DefaultDict, Generic, List, Optional, Tuple, Type, TypeVar, Union +from typing import Any, DefaultDict, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union import inflection import uritemplate @@ -39,12 +39,16 @@ from rest_framework.utils.mediatypes import _MediaType from uritemplate import URITemplate -from drf_spectacular.drainage import Literal, _TypedDictMeta, cache, error, warn +from drf_spectacular.drainage import Literal, TypeGuard, _TypedDictMeta, cache, error, warn from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import ( DJANGO_PATH_CONVERTER_MAPPING, OPENAPI_TYPE_MAPPING, PYTHON_TYPE_MAPPING, OpenApiTypes, + _KnownPythonTypes, +) +from drf_spectacular.utils import ( + OpenApiExample, OpenApiParameter, _FieldType, _ListSerializerType, _ParameterLocationType, + _SchemaType, _SerializerType, ) -from drf_spectacular.utils import OpenApiParameter try: from django.db.models.enums import Choices # only available in Django>3 @@ -54,14 +58,14 @@ class Choices: # type: ignore # types.UnionType was added in Python 3.10 for new PEP 604 pipe union syntax if hasattr(types, 'UnionType'): - UNION_TYPES: Tuple[Any, ...] = (typing.Union, types.UnionType) # type: ignore + UNION_TYPES: Tuple[Any, ...] = (Union, types.UnionType) # type: ignore else: - UNION_TYPES = (typing.Union,) + UNION_TYPES = (Union,) if sys.version_info >= (3, 8): - CACHED_PROPERTY_FUNCS = (functools.cached_property, cached_property) # type: ignore + CACHED_PROPERTY_FUNCS = (functools.cached_property, cached_property) else: - CACHED_PROPERTY_FUNCS = (cached_property,) # type: ignore + CACHED_PROPERTY_FUNCS = (cached_property,) T = TypeVar('T') @@ -83,7 +87,7 @@ def force_instance(serializer_or_field): return serializer_or_field -def is_serializer(obj) -> bool: +def is_serializer(obj) -> TypeGuard[_SerializerType]: from drf_spectacular.serializers import OpenApiSerializerExtension return ( isinstance(force_instance(obj), serializers.BaseSerializer) @@ -91,21 +95,21 @@ def is_serializer(obj) -> bool: ) -def is_list_serializer(obj) -> bool: +def is_list_serializer(obj) -> TypeGuard[_ListSerializerType]: return isinstance(force_instance(obj), serializers.ListSerializer) -def is_basic_serializer(obj) -> bool: +def is_basic_serializer(obj) -> TypeGuard[_SerializerType]: return is_serializer(obj) and not is_list_serializer(obj) -def is_field(obj): +def is_field(obj) -> TypeGuard[_FieldType]: # make sure obj is a serializer field and nothing else. # guard against serializers because BaseSerializer(Field) return isinstance(force_instance(obj), fields.Field) and not is_serializer(obj) -def is_basic_type(obj, allow_none=True): +def is_basic_type(obj, allow_none=True) -> TypeGuard[_KnownPythonTypes]: if not isinstance(obj, collections.abc.Hashable): return False if not allow_none and (obj is None or obj is OpenApiTypes.NONE): @@ -113,7 +117,7 @@ def is_basic_type(obj, allow_none=True): return obj in get_openapi_type_mapping() or obj in PYTHON_TYPE_MAPPING -def is_patched_serializer(serializer, direction): +def is_patched_serializer(serializer, direction) -> bool: return bool( spectacular_settings.COMPONENT_SPLIT_PATCH and serializer.partial @@ -122,13 +126,13 @@ def is_patched_serializer(serializer, direction): ) -def is_trivial_string_variation(a: str, b: str): +def is_trivial_string_variation(a: str, b: str) -> bool: a = (a or '').strip().lower().replace(' ', '_').replace('-', '_') b = (b or '').strip().lower().replace(' ', '_').replace('-', '_') return a == b -def assert_basic_serializer(serializer): +def assert_basic_serializer(serializer) -> None: assert is_basic_serializer(serializer), ( f'internal assumption violated because we expected a basic serializer here and ' f'instead got a "{serializer}". This may be the result of another app doing ' @@ -174,7 +178,7 @@ def get_view_model(view, emit_warnings=True): ) -def get_doc(obj): +def get_doc(obj) -> str: """ get doc string with fallback on obj's base classes (ignoring DRF documentation). """ def post_cleanup(doc: str): # also clean up trailing whitespace for each line @@ -198,7 +202,7 @@ def safe_index(lst, item): return '' -def get_type_hints(obj): +def get_type_hints(obj) -> Dict[str, Any]: """ unpack wrapped partial object and use actual func object """ if isinstance(obj, functools.partial): obj = obj.func @@ -213,7 +217,7 @@ def get_openapi_type_mapping(): } -def build_generic_type(): +def build_generic_type() -> dict: if spectacular_settings.GENERIC_ADDITIONAL_PROPERTIES is None: return {'type': 'object'} elif spectacular_settings.GENERIC_ADDITIONAL_PROPERTIES == 'bool': @@ -222,7 +226,7 @@ def build_generic_type(): return {'type': 'object', 'additionalProperties': {}} -def build_basic_type(obj): +def build_basic_type(obj: Union[_KnownPythonTypes, OpenApiTypes]) -> Optional[_SchemaType]: """ resolve either enum or actual type and yield schema template for modification """ @@ -238,7 +242,7 @@ def build_basic_type(obj): return dict(openapi_type_mapping[OpenApiTypes.STR]) -def build_array_type(schema, min_length=None, max_length=None): +def build_array_type(schema, min_length=None, max_length=None) -> _SchemaType: schema = {'type': 'array', 'items': schema} if min_length is not None: schema['minLength'] = min_length @@ -248,12 +252,12 @@ def build_array_type(schema, min_length=None, max_length=None): def build_object_type( - properties=None, + properties: Optional[List[dict]] = None, required=None, - description=None, + description: Optional[str] = None, **kwargs -): - schema = {'type': 'object'} +) -> _SchemaType: + schema: _SchemaType = {'type': 'object'} if description: schema['description'] = description.strip() if properties: @@ -266,14 +270,14 @@ def build_object_type( return schema -def build_media_type_object(schema, examples=None): +def build_media_type_object(schema: _SchemaType, examples=None) -> _SchemaType: media_type_object = {'schema': schema} if examples: media_type_object['examples'] = examples return media_type_object -def build_examples_list(examples): +def build_examples_list(examples: List[OpenApiExample]) -> _SchemaType: schema = {} for example in examples: normalized_name = inflection.camelize(example.name.replace(' ', '_')) @@ -293,9 +297,9 @@ def build_examples_list(examples): def build_parameter_type( - name, - schema, - location, + name: str, + schema: _SchemaType, + location: _ParameterLocationType, required=False, description=None, enum=None, @@ -306,7 +310,7 @@ def build_parameter_type( allow_blank=True, examples=None, extensions=None, -): +) -> _SchemaType: irrelevant_field_meta = ['readOnly', 'writeOnly'] if location == OpenApiParameter.PATH: irrelevant_field_meta += ['nullable', 'default'] @@ -338,11 +342,11 @@ def build_parameter_type( return schema -def build_choice_field(field): +def build_choice_field(field) -> _SchemaType: choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates if all(isinstance(choice, bool) for choice in choices): - type = 'boolean' + type: Optional[str] = 'boolean' elif all(isinstance(choice, int) for choice in choices): type = 'integer' elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer` @@ -358,7 +362,7 @@ def build_choice_field(field): if field.allow_null: choices.append(None) - schema = { + schema: _SchemaType = { # The value of `enum` keyword MUST be an array and SHOULD be unique. # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20 'enum': choices @@ -373,7 +377,7 @@ def build_choice_field(field): return schema -def build_bearer_security_scheme_object(header_name, token_prefix, bearer_format=None): +def build_bearer_security_scheme_object(header_name, token_prefix, bearer_format=None) -> _SchemaType: """ Either build a bearer scheme or a fallback due to OpenAPI 3.0.3 limitations """ # normalize Django header quirks if header_name.startswith('HTTP_'): @@ -397,7 +401,7 @@ def build_bearer_security_scheme_object(header_name, token_prefix, bearer_format } -def build_root_object(paths, components, version): +def build_root_object(paths, components, version) -> _SchemaType: settings = spectacular_settings if settings.VERSION and version: version = f'{settings.VERSION} ({version})' @@ -431,7 +435,7 @@ def build_root_object(paths, components, version): return root -def safe_ref(schema): +def safe_ref(schema: _SchemaType) -> _SchemaType: """ ensure that $ref has its own context and does not remove potential sibling entries when $ref is substituted. @@ -441,7 +445,7 @@ def safe_ref(schema): return schema -def append_meta(schema, meta): +def append_meta(schema: _SchemaType, meta: _SchemaType) -> _SchemaType: return safe_ref({**schema, **meta}) @@ -752,7 +756,7 @@ def load_enum_name_overrides(): return overrides -def list_hash(lst): +def list_hash(lst: List[Any]) -> str: return hashlib.sha256(json.dumps(list(lst), sort_keys=True).encode()).hexdigest() @@ -855,7 +859,7 @@ def resolve_regex_path_parameter(path_regex, variable): return None -def is_versioning_supported(versioning_class): +def is_versioning_supported(versioning_class) -> bool: return issubclass(versioning_class, ( versioning.URLPathVersioning, versioning.NamespaceVersioning, @@ -863,7 +867,7 @@ def is_versioning_supported(versioning_class): )) -def operation_matches_version(view, requested_version): +def operation_matches_version(view, requested_version) -> bool: try: version, _ = view.determine_version(view.request, **view.kwargs) except exceptions.NotAcceptable: @@ -929,7 +933,7 @@ def modify_media_types_for_versioning(view, media_types: List[str]) -> List[str] ] -def analyze_named_regex_pattern(path): +def analyze_named_regex_pattern(path: str) -> Dict[str, str]: """ safely extract named groups and their pattern from given regex pattern """ result = {} stack = 0 @@ -1192,7 +1196,7 @@ def resolve_type_hint(hint): raise UnableToProceedError() -def whitelisted(obj: object, classes: List[Type[object]], exact=False): +def whitelisted(obj: object, classes: List[Type[object]], exact=False) -> bool: if not classes: return True if exact: diff --git a/drf_spectacular/utils.py b/drf_spectacular/utils.py index a4db96ae..ecd000de 100644 --- a/drf_spectacular/utils.py +++ b/drf_spectacular/utils.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union from rest_framework.fields import Field, empty -from rest_framework.serializers import Serializer +from rest_framework.serializers import ListSerializer, Serializer from rest_framework.settings import api_settings from drf_spectacular.drainage import ( @@ -10,9 +10,11 @@ ) from drf_spectacular.types import OpenApiTypes, _KnownPythonTypes +_ListSerializerType = Union[ListSerializer, Type[ListSerializer]] _SerializerType = Union[Serializer, Type[Serializer]] _FieldType = Union[Field, Type[Field]] _ParameterLocationType = Literal['query', 'path', 'header', 'cookie'] +_SchemaType = Dict[str, Any] Direction = Literal['request', 'response'] diff --git a/requirements/docs.txt b/requirements/docs.txt index 486e4406..ca42c48f 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,2 +1,3 @@ Sphinx>=4.1.0 sphinx_rtd_theme>=0.5.1 +typing-extensions diff --git a/tox.ini b/tox.ini index 2d41e6fe..4cb27551 100644 --- a/tox.ini +++ b/tox.ini @@ -86,10 +86,27 @@ include_trailing_comma = true [mypy] python_version = 3.8 plugins = mypy_django_plugin.main,mypy_drf_plugin.main +warn_unused_configs = True +warn_redundant_casts = True +warn_unused_ignores = True [mypy.plugins.django-stubs] django_settings_module = "tests.settings" +[mypy-drf_spectacular.*] +strict_equality = True +no_implicit_optional = True +disallow_untyped_decorators = True +disallow_subclassing_any = True +;check_untyped_defs = True +;warn_return_any = True +;no_implicit_reexport = True +;disallow_incomplete_defs = True +;disallow_any_generics = True +;disallow_untyped_calls = True +;disallow_untyped_defs = True + + [mypy-rest_framework.compat.*] ignore_missing_imports = True