From f6666202e11a8e05795312293304c9db0b29e17e Mon Sep 17 00:00:00 2001 From: "T. Franzel" Date: Thu, 11 Nov 2021 11:44:45 +0100 Subject: [PATCH] improve mypy typing #600 --- drf_spectacular/contrib/rest_polymorphic.py | 3 +- drf_spectacular/drainage.py | 46 ++++--- drf_spectacular/extensions.py | 21 +-- drf_spectacular/generators.py | 12 +- drf_spectacular/hooks.py | 3 +- drf_spectacular/openapi.py | 114 +++++++++------- drf_spectacular/plumbing.py | 127 ++++++++++-------- drf_spectacular/utils.py | 10 +- requirements/base.txt | 2 +- requirements/docs.txt | 1 + tests/contrib/test_drf_spectacular_sidecar.py | 2 +- tests/contrib/test_pydantic.py | 2 +- tests/test_fields.py | 4 +- tox.ini | 19 ++- 14 files changed, 220 insertions(+), 146 deletions(-) diff --git a/drf_spectacular/contrib/rest_polymorphic.py b/drf_spectacular/contrib/rest_polymorphic.py index 431fb974..943e0799 100644 --- a/drf_spectacular/contrib/rest_polymorphic.py +++ b/drf_spectacular/contrib/rest_polymorphic.py @@ -1,6 +1,7 @@ +from drf_spectacular.drainage import warn from drf_spectacular.extensions import OpenApiSerializerExtension from drf_spectacular.plumbing import ( - ResolvedComponent, build_basic_type, build_object_type, is_patched_serializer, warn, + ResolvedComponent, build_basic_type, build_object_type, is_patched_serializer, ) from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import OpenApiTypes diff --git a/drf_spectacular/drainage.py b/drf_spectacular/drainage.py index 6b9de36d..9b47dc3e 100644 --- a/drf_spectacular/drainage.py +++ b/drf_spectacular/drainage.py @@ -3,7 +3,21 @@ import inspect import sys from collections import defaultdict -from typing import DefaultDict, List, Optional, Tuple +from typing import Any, Callable, DefaultDict, List, Optional, Tuple, TypeVar + +if sys.version_info >= (3, 8): + from typing import ( # type: ignore[attr-defined] # noqa: F401 + Final, Literal, TypedDict, _TypedDictMeta, + ) +else: + from typing_extensions import Final, Literal, TypedDict, _TypedDictMeta # noqa: F401 + +if sys.version_info >= (3, 10): + from typing import TypeGuard # noqa: F401 +else: + from typing_extensions import TypeGuard # noqa: F401 + +F = TypeVar('F', bound=Callable[..., Any]) class GeneratorStats: @@ -37,20 +51,20 @@ def silence(self): finally: self.silent = tmp - def reset(self): + def reset(self) -> None: self._warn_cache.clear() self._error_cache.clear() - def enable_color(self): + def enable_color(self) -> None: self._blue = '\033[0;34m' self._red = '\033[0;31m' self._yellow = '\033[0;33m' self._clear = '\033[0m' - def enable_trace_lineno(self): + def enable_trace_lineno(self) -> None: self._trace_lineno = True - def _get_current_trace(self): + def _get_current_trace(self) -> Tuple[Optional[str], str]: source_locations = [t for t in self._traces if t[0]] if source_locations: sourcefile, lineno, _ = source_locations[-1] @@ -60,7 +74,7 @@ def _get_current_trace(self): breadcrumbs = ' > '.join(t[2] for t in self._traces) return source_location, breadcrumbs - def emit(self, msg, severity): + def emit(self, msg: str, severity: str) -> None: assert severity in ['warning', 'error'] cache = self._warn_cache if severity == 'warning' else self._error_cache @@ -75,7 +89,7 @@ def emit(self, msg, severity): print(msg, file=sys.stderr) cache[msg] += 1 - def emit_summary(self): + def emit_summary(self) -> None: if not self.silent and (self._warn_cache or self._error_cache): print( f'\nSchema generation summary:\n' @@ -88,7 +102,7 @@ def emit_summary(self): GENERATOR_STATS = GeneratorStats() -def warn(msg, delayed=None): +def warn(msg: str, delayed: Any = None) -> None: if delayed: warnings = get_override(delayed, 'warnings', []) warnings.append(msg) @@ -97,7 +111,7 @@ def warn(msg, delayed=None): GENERATOR_STATS.emit(msg, 'warning') -def error(msg, delayed=None): +def error(msg: str, delayed: Any = None) -> None: if delayed: errors = get_override(delayed, 'errors', []) errors.append(msg) @@ -106,7 +120,7 @@ def error(msg, delayed=None): GENERATOR_STATS.emit(msg, 'error') -def reset_generator_stats(): +def reset_generator_stats() -> None: GENERATOR_STATS.reset() @@ -136,7 +150,7 @@ def _get_source_location(obj): return sourcefile, lineno -def has_override(obj, prop): +def has_override(obj: Any, prop: str) -> bool: if isinstance(obj, functools.partial): obj = obj.func if not hasattr(obj, '_spectacular_annotation'): @@ -146,7 +160,7 @@ def has_override(obj, prop): return True -def get_override(obj, prop, default=None): +def get_override(obj: Any, prop: str, default: Any = None) -> Any: if isinstance(obj, functools.partial): obj = obj.func if not has_override(obj, prop): @@ -154,7 +168,7 @@ def get_override(obj, prop, default=None): return obj._spectacular_annotation[prop] -def set_override(obj, prop, value): +def set_override(obj: Any, prop: str, value: Any) -> Any: if not hasattr(obj, '_spectacular_annotation'): obj._spectacular_annotation = {} elif '_spectacular_annotation' not in obj.__dict__: @@ -163,7 +177,7 @@ def set_override(obj, prop, value): return obj -def get_view_method_names(view, schema=None): +def get_view_method_names(view, schema=None) -> List[str]: schema = schema or view.schema return [ item for item in dir(view) if callable(getattr(view, item)) and ( @@ -201,6 +215,6 @@ def wrapped_method(self, request, *args, **kwargs): return wrapped_method -def cache(user_function): +def cache(user_function: F) -> F: """ simple polyfill for python < 3.9 """ - return functools.lru_cache(maxsize=None)(user_function) + return functools.lru_cache(maxsize=None)(user_function) # type: ignore diff --git a/drf_spectacular/extensions.py b/drf_spectacular/extensions.py index c30b7a90..3f407a6c 100644 --- a/drf_spectacular/extensions.py +++ b/drf_spectacular/extensions.py @@ -10,6 +10,9 @@ from drf_spectacular.openapi import AutoSchema +_SchemaType = Dict[str, Any] + + class OpenApiAuthenticationExtension(OpenApiGeneratorExtension['OpenApiAuthenticationExtension']): """ Extension for specifying authentication schemes. @@ -29,7 +32,7 @@ class OpenApiAuthenticationExtension(OpenApiGeneratorExtension['OpenApiAuthentic ``get_security_definition()`` is expected to return a valid `OpenAPI security scheme object `_ """ - _registry: List['OpenApiAuthenticationExtension'] = [] + _registry: List[Type['OpenApiAuthenticationExtension']] = [] name: Union[str, List[str]] @@ -43,7 +46,7 @@ def get_security_requirement( return {name: [] for name in self.name} @abstractmethod - def get_security_definition(self, auto_schema: 'AutoSchema') -> Union[dict, List[dict]]: + def get_security_definition(self, auto_schema: 'AutoSchema') -> Union[_SchemaType, List[_SchemaType]]: pass # pragma: no cover @@ -59,13 +62,13 @@ class OpenApiSerializerExtension(OpenApiGeneratorExtension['OpenApiSerializerExt ``map_serializer()`` is expected to return a valid `OpenAPI schema object `_. """ - _registry: List['OpenApiSerializerExtension'] = [] + _registry: List[Type['OpenApiSerializerExtension']] = [] def get_name(self, auto_schema: 'AutoSchema', direction: Direction) -> Optional[str]: """ return str for overriding default name extraction """ return None - def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction): + def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType: """ override for customized serializer mapping """ return auto_schema._map_serializer(self.target_class, direction, bypass_extensions=True) @@ -82,14 +85,14 @@ class OpenApiSerializerFieldExtension(OpenApiGeneratorExtension['OpenApiSerializ ``map_serializer_field()`` is expected to return a valid `OpenAPI schema object `_. """ - _registry: List['OpenApiSerializerFieldExtension'] = [] + _registry: List[Type['OpenApiSerializerFieldExtension']] = [] def get_name(self) -> Optional[str]: """ return str for breaking out field schema into separate named component """ return None @abstractmethod - def map_serializer_field(self, auto_schema: 'AutoSchema', direction: Direction): + def map_serializer_field(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType: """ override for customized serializer field mapping """ pass # pragma: no cover @@ -102,7 +105,7 @@ class OpenApiViewExtension(OpenApiGeneratorExtension['OpenApiViewExtension']): ``ViewSet`` et al.). The discovered original view instance can be accessed with ``self.target`` and be subclassed if desired. """ - _registry: List['OpenApiViewExtension'] = [] + _registry: List[Type['OpenApiViewExtension']] = [] @classmethod def _load_class(cls): @@ -129,8 +132,8 @@ class OpenApiFilterExtension(OpenApiGeneratorExtension['OpenApiFilterExtension'] Using ``drf_spectacular.plumbing.build_parameter_type`` is recommended to generate the appropriate raw dict objects. """ - _registry: List['OpenApiFilterExtension'] = [] + _registry: List[Type['OpenApiFilterExtension']] = [] @abstractmethod - def get_schema_operation_parameters(self, auto_schema: 'AutoSchema', *args, **kwargs) -> List[dict]: + def get_schema_operation_parameters(self, auto_schema: 'AutoSchema', *args, **kwargs) -> List[_SchemaType]: pass # pragma: no cover diff --git a/drf_spectacular/generators.py b/drf_spectacular/generators.py index ad71d35f..e3068845 100644 --- a/drf_spectacular/generators.py +++ b/drf_spectacular/generators.py @@ -3,17 +3,19 @@ from django.urls import URLPattern, URLResolver from rest_framework import views, viewsets -from rest_framework.schemas.generators import BaseSchemaGenerator # type: ignore +from rest_framework.schemas.generators import BaseSchemaGenerator from rest_framework.schemas.generators import EndpointEnumerator as BaseEndpointEnumerator from rest_framework.settings import api_settings -from drf_spectacular.drainage import add_trace_message, get_override, reset_generator_stats +from drf_spectacular.drainage import ( + add_trace_message, error, get_override, reset_generator_stats, warn, +) from drf_spectacular.extensions import OpenApiViewExtension from drf_spectacular.openapi import AutoSchema from drf_spectacular.plumbing import ( - ComponentRegistry, alpha_operation_sorter, build_root_object, camelize_operation, error, - get_class, is_versioning_supported, modify_for_versioning, normalize_result_object, - operation_matches_version, sanitize_result_object, warn, + ComponentRegistry, alpha_operation_sorter, build_root_object, camelize_operation, get_class, + is_versioning_supported, modify_for_versioning, normalize_result_object, + operation_matches_version, sanitize_result_object, ) from drf_spectacular.settings import spectacular_settings diff --git a/drf_spectacular/hooks.py b/drf_spectacular/hooks.py index 3b52e427..df9d30df 100644 --- a/drf_spectacular/hooks.py +++ b/drf_spectacular/hooks.py @@ -4,8 +4,9 @@ from inflection import camelize from rest_framework.settings import api_settings +from drf_spectacular.drainage import warn from drf_spectacular.plumbing import ( - ResolvedComponent, list_hash, load_enum_name_overrides, safe_ref, warn, + ResolvedComponent, list_hash, load_enum_name_overrides, safe_ref, ) from drf_spectacular.settings import spectacular_settings diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index 497fc606..ca1967bb 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -2,8 +2,8 @@ import functools import itertools import re -import typing from collections import defaultdict +from typing import Any, Dict, List, Optional, Union import uritemplate from django.core import exceptions as django_exceptions @@ -15,32 +15,37 @@ from rest_framework.generics import CreateAPIView, GenericAPIView, ListCreateAPIView from rest_framework.mixins import ListModelMixin from rest_framework.schemas.inspectors import ViewInspector -from rest_framework.schemas.utils import get_pk_description # type: ignore +from rest_framework.schemas.utils import get_pk_description from rest_framework.settings import api_settings from rest_framework.utils.model_meta import get_field_info from rest_framework.views import APIView -from drf_spectacular.authentication import OpenApiAuthenticationExtension +import drf_spectacular.authentication # noqa: F403, F401 +import drf_spectacular.serializers # noqa: F403, F401 from drf_spectacular.contrib import * # noqa: F403, F401 -from drf_spectacular.drainage import add_trace_message, get_override, has_override +from drf_spectacular.drainage import add_trace_message, error, get_override, has_override, warn from drf_spectacular.extensions import ( - OpenApiFilterExtension, OpenApiSerializerExtension, OpenApiSerializerFieldExtension, + OpenApiAuthenticationExtension, OpenApiFilterExtension, OpenApiSerializerExtension, + OpenApiSerializerFieldExtension, ) from drf_spectacular.plumbing import ( ComponentRegistry, ResolvedComponent, UnableToProceedError, append_meta, assert_basic_serializer, build_array_type, build_basic_type, build_choice_field, build_examples_list, build_generic_type, build_listed_example_value, build_media_type_object, - build_mocked_view, build_object_type, build_parameter_type, build_serializer_context, error, + build_mocked_view, build_object_type, build_parameter_type, build_serializer_context, filter_supported_arguments, follow_field_source, follow_model_field_lookup, force_instance, get_doc, get_list_serializer, get_manager, get_type_hints, get_view_model, is_basic_serializer, is_basic_type, is_field, is_list_serializer, is_list_serializer_customized, is_patched_serializer, is_serializer, is_trivial_string_variation, modify_media_types_for_versioning, resolve_django_path_parameter, resolve_regex_path_parameter, - resolve_type_hint, safe_ref, sanitize_specification_extensions, warn, whitelisted, + resolve_type_hint, safe_ref, sanitize_specification_extensions, whitelisted, ) from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import OpenApiCallback, OpenApiParameter, OpenApiRequest, OpenApiResponse +from drf_spectacular.utils import ( + Direction, OpenApiCallback, OpenApiExample, OpenApiParameter, OpenApiRequest, OpenApiResponse, + _SchemaType, _SerializerType, +) class AutoSchema(ViewInspector): @@ -52,7 +57,14 @@ class AutoSchema(ViewInspector): 'delete': 'destroy', } - def get_operation(self, path, path_regex, path_prefix, method, registry: ComponentRegistry): + def get_operation( + self, + path: str, + path_regex: str, + path_prefix: str, + method: str, + registry: ComponentRegistry + ) -> Optional[_SchemaType]: self.registry = registry self.path = path self.path_regex = path_regex @@ -62,7 +74,7 @@ def get_operation(self, path, path_regex, path_prefix, method, registry: Compone if self.is_excluded(): return None - operation = {'operationId': self.get_operation_id()} + operation: _SchemaType = {'operationId': self.get_operation_id()} description = self.get_description() if description: @@ -108,11 +120,11 @@ def get_operation(self, path, path_regex, path_prefix, method, registry: Compone return operation - def is_excluded(self): + def is_excluded(self) -> bool: """ override this for custom behaviour """ return False - def _is_list_view(self, serializer=None): + def _is_list_view(self, serializer: Optional[_SerializerType] = None) -> bool: """ partially heuristic approach to determine if a view yields an object or a list of objects. used for operationId naming, array building and pagination. @@ -145,7 +157,7 @@ def _is_list_view(self, serializer=None): return False - def _is_create_operation(self): + def _is_create_operation(self) -> bool: if self.method != 'POST': return False if getattr(self.view, 'action', None) == 'create': @@ -154,7 +166,7 @@ def _is_create_operation(self): return True return False - def get_override_parameters(self): + def get_override_parameters(self) -> List[Union[OpenApiParameter, _SerializerType]]: """ override this for custom behaviour """ return [] @@ -221,19 +233,19 @@ def _process_override_parameters(self, direction='request'): warn(f'could not resolve parameter annotation {parameter}. Skipping.') return result - def _get_format_parameters(self): + def _get_format_parameters(self) -> List[_SchemaType]: parameters = [] formats = self.map_renderers('format') if api_settings.URL_FORMAT_OVERRIDE and len(formats) > 1: parameters.append(build_parameter_type( name=api_settings.URL_FORMAT_OVERRIDE, - schema=build_basic_type(OpenApiTypes.STR), + schema=build_basic_type(OpenApiTypes.STR), # type: ignore location=OpenApiParameter.QUERY, enum=formats )) return parameters - def _get_parameters(self): + def _get_parameters(self) -> List[_SchemaType]: def dict_helper(parameters): return {(p['name'], p['in']): p for p in parameters} @@ -275,29 +287,29 @@ def dict_helper(parameters): else: return list(parameters.values()) - def get_description(self): + def get_description(self) -> str: # type: ignore[override] """ override this for custom behaviour """ action_or_method = getattr(self.view, getattr(self.view, 'action', self.method.lower()), None) view_doc = get_doc(self.view.__class__) action_doc = get_doc(action_or_method) return action_doc or view_doc - def get_summary(self): + def get_summary(self) -> Optional[str]: """ override this for custom behaviour """ return None - def _get_external_docs(self): + def _get_external_docs(self) -> Optional[Dict[str, str]]: external_docs = self.get_external_docs() if isinstance(external_docs, str): return {'url': external_docs} else: return external_docs - def get_external_docs(self): + def get_external_docs(self) -> Optional[Union[Dict[str, str], str]]: """ override this for custom behaviour """ return None - def get_auth(self): + def get_auth(self) -> List[_SchemaType]: """ Obtains authentication classes and permissions from view. If authentication is known, resolve security requirement for endpoint and security definition for @@ -329,7 +341,7 @@ def get_auth(self): if isinstance(scheme.name, str): names, definitions = [scheme.name], [scheme.get_security_definition(self)] else: - names, definitions = scheme.name, scheme.get_security_definition(self) + names, definitions = scheme.name, scheme.get_security_definition(self) # type: ignore[assignment] for name, definition in zip(names, definitions): self.registry.register_on_missing( @@ -351,21 +363,21 @@ def get_auth(self): auths.append({}) return auths - def get_request_serializer(self) -> typing.Any: + def get_request_serializer(self) -> Optional[_SerializerType]: """ override this for custom behaviour """ return self._get_serializer() - def get_response_serializers(self) -> typing.Any: + def get_response_serializers(self) -> Optional[_SerializerType]: """ override this for custom behaviour """ return self._get_serializer() - def get_tags(self) -> typing.List[str]: + def get_tags(self) -> List[str]: """ override this for custom behaviour """ tokenized_path = self._tokenize_path() # use first non-parameter path part as tag return tokenized_path[:1] - def get_extensions(self) -> typing.Dict[str, typing.Any]: + def get_extensions(self) -> _SchemaType: return {} def _get_callbacks(self): @@ -426,11 +438,11 @@ def _get_callbacks(self): return result - def get_callbacks(self) -> typing.List[OpenApiCallback]: + def get_callbacks(self) -> List[OpenApiCallback]: """ override this for custom behaviour """ return [] - def get_operation_id(self): + def get_operation_id(self) -> str: """ override this for custom behaviour """ tokenized_path = self._tokenize_path() # replace dashes as they can be problematic later in code generation @@ -449,11 +461,11 @@ def get_operation_id(self): return '_'.join(tokenized_path + [action]) - def is_deprecated(self): + def is_deprecated(self) -> bool: """ override this for custom behaviour """ return False - def _tokenize_path(self): + def _tokenize_path(self) -> List[str]: # remove path prefix path = re.sub( pattern=self.path_prefix, @@ -464,8 +476,8 @@ def _tokenize_path(self): # remove path variables path = re.sub(pattern=r'\{[\w\-]+\}', repl='', string=path) # cleanup and tokenize remaining parts. - path = path.rstrip('/').lstrip('/').split('/') - return [t for t in path if t] + tokenized_path = path.rstrip('/').lstrip('/').split('/') + return [t for t in tokenized_path if t] def _resolve_path_parameters(self, variables): model = get_view_model(self.view, emit_warnings=False) @@ -519,7 +531,7 @@ def _resolve_path_parameters(self, variables): return parameters - def get_filter_backends(self): + def get_filter_backends(self) -> List[Any]: """ override this for custom behaviour """ if not self._is_list_view(): return [] @@ -895,7 +907,7 @@ def _map_serializer_field(self, field, direction, bypass_extensions=False): warn(f'could not resolve serializer field "{field}". Defaulting to "string"') return append_meta(build_basic_type(OpenApiTypes.STR), meta) - def _insert_min_max(self, field, content): + def _insert_min_max(self, field: Any, content: _SchemaType) -> None: if field.max_value is not None: content['maximum'] = field.max_value if 'exclusiveMaximum' in content: @@ -1131,16 +1143,16 @@ def _get_paginator(self): return pagination_class() return None - def get_paginated_name(self, serializer_name): + def get_paginated_name(self, serializer_name: str) -> str: return f'Paginated{serializer_name}List' - def map_parsers(self): + def map_parsers(self) -> List[Any]: return list(dict.fromkeys([ p.media_type for p in self.view.get_parsers() if whitelisted(p, spectacular_settings.PARSER_WHITELIST) ])) - def map_renderers(self, attribute): + def map_renderers(self, attribute: str) -> List[Any]: assert attribute in ['media_type', 'format'] # Either use whitelist or default back to old behavior by excluding BrowsableAPIRenderer @@ -1190,7 +1202,7 @@ def _get_serializer(self): f'a request? Ignoring the view for now. (Exception: {exc})' ) - def get_examples(self): + def get_examples(self) -> List[OpenApiExample]: """ override this for custom behaviour """ return [] @@ -1340,7 +1352,7 @@ def _get_request_for_media_type(self, serializer, direction='request'): request_body_required = False return schema, request_body_required - def _get_response_bodies(self, direction='response'): + def _get_response_bodies(self, direction: Direction = 'response') -> _SchemaType: response_serializers = self.get_response_serializers() if ( @@ -1374,10 +1386,10 @@ def _get_response_bodies(self, direction='response'): f'Defaulting to generic free-form object.' ) schema = build_basic_type(OpenApiTypes.OBJECT) - schema['description'] = _('Unspecified response body') + schema['description'] = _('Unspecified response body') # type: ignore return {'200': self._get_response_for_code(schema, '200', direction=direction)} - def _unwrap_list_serializer(self, serializer, direction) -> typing.Optional[dict]: + def _unwrap_list_serializer(self, serializer, direction: Direction) -> Optional[_SchemaType]: if is_field(serializer): return self._map_serializer_field(serializer, direction) elif is_basic_serializer(serializer): @@ -1479,7 +1491,7 @@ def _get_response_for_code(self, serializer, status_code, media_types=None, dire 'description': description } - def _get_response_headers_for_code(self, status_code, direction='response') -> dict: + def _get_response_headers_for_code(self, status_code, direction='response') -> _SchemaType: result = {} for parameter in self.get_override_parameters(): if not isinstance(parameter, OpenApiParameter): @@ -1497,7 +1509,11 @@ def _get_response_headers_for_code(self, status_code, direction='response') -> d elif is_serializer(parameter.type): schema = self.resolve_serializer(parameter.type, direction).ref else: - schema = parameter.type + schema = parameter.type # type: ignore + + if not schema: + warn(f'response parameter {parameter.name} requires non-empty schema') + continue if parameter.location not in [OpenApiParameter.HEADER, OpenApiParameter.COOKIE]: warn(f'incompatible location type ignored for response parameter {parameter.name}') @@ -1524,10 +1540,10 @@ def _get_response_headers_for_code(self, status_code, direction='response') -> d return result - def get_serializer_name(self, serializer, direction): + def get_serializer_name(self, serializer: serializers.Serializer, direction: Direction) -> str: return serializer.__class__.__name__ - def _get_serializer_name(self, serializer, direction, bypass_extensions=False): + def _get_serializer_name(self, serializer, direction, bypass_extensions=False) -> str: serializer_extension = OpenApiSerializerExtension.get_match(serializer) if serializer_extension and not bypass_extensions: custom_name = serializer_extension.get_name(**filter_supported_arguments( @@ -1550,6 +1566,8 @@ def _get_serializer_name(self, serializer, direction, bypass_extensions=False): else: name = self.get_serializer_name(serializer, direction) + assert name + if name.endswith('Serializer'): name = name[:-10] @@ -1568,7 +1586,9 @@ def _get_serializer_name(self, serializer, direction, bypass_extensions=False): return name - def resolve_serializer(self, serializer, direction, bypass_extensions=False) -> ResolvedComponent: + def resolve_serializer( + self, serializer: _SerializerType, direction: Direction, bypass_extensions=False + ) -> ResolvedComponent: assert_basic_serializer(serializer) serializer = force_instance(serializer) diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 3a31252b..ae47f236 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -12,7 +12,15 @@ from collections import OrderedDict, defaultdict from decimal import Decimal from enum import Enum -from typing import Any, DefaultDict, Generic, List, Optional, Tuple, Type, TypeVar, Union +from typing import ( + Any, DefaultDict, Dict, Generic, List, Optional, Sequence, Tuple, Type, TypeVar, Union, +) + +if sys.version_info >= (3, 10): + from typing import TypeGuard # noqa: F401 +else: + from typing_extensions import TypeGuard # noqa: F401 + import inflection import uritemplate @@ -46,8 +54,12 @@ from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import ( DJANGO_PATH_CONVERTER_MAPPING, OPENAPI_TYPE_MAPPING, PYTHON_TYPE_MAPPING, OpenApiTypes, + _KnownPythonTypes, +) +from drf_spectacular.utils import ( + OpenApiExample, OpenApiParameter, _FieldType, _ListSerializerType, _ParameterLocationType, + _SchemaType, _SerializerType, ) -from drf_spectacular.utils import OpenApiParameter try: from django.db.models.enums import Choices # only available in Django>3 @@ -57,9 +69,9 @@ class Choices: # type: ignore # types.UnionType was added in Python 3.10 for new PEP 604 pipe union syntax if hasattr(types, 'UnionType'): - UNION_TYPES: Tuple[Any, ...] = (typing.Union, types.UnionType) # type: ignore + UNION_TYPES: Tuple[Any, ...] = (Union, types.UnionType) else: - UNION_TYPES = (typing.Union,) + UNION_TYPES = (Union,) LITERAL_TYPES: Tuple[Any, ...] = () TYPED_DICT_META_TYPES: Tuple[Any, ...] = () @@ -79,9 +91,9 @@ class Choices: # type: ignore pass if sys.version_info >= (3, 8): - CACHED_PROPERTY_FUNCS = (functools.cached_property, cached_property) # type: ignore + CACHED_PROPERTY_FUNCS = (functools.cached_property, cached_property) else: - CACHED_PROPERTY_FUNCS = (cached_property,) # type: ignore + CACHED_PROPERTY_FUNCS = (cached_property,) T = TypeVar('T') @@ -103,19 +115,20 @@ def force_instance(serializer_or_field): return serializer_or_field -def is_serializer(obj, strict=False) -> bool: - from drf_spectacular.serializers import OpenApiSerializerExtension +def is_serializer(obj, strict=False) -> TypeGuard[_SerializerType]: + + from drf_spectacular.extensions import OpenApiSerializerExtension return ( isinstance(force_instance(obj), serializers.BaseSerializer) or (bool(OpenApiSerializerExtension.get_match(obj)) and not strict) ) -def is_list_serializer(obj) -> bool: +def is_list_serializer(obj: Any) -> TypeGuard[_ListSerializerType]: return isinstance(force_instance(obj), serializers.ListSerializer) -def get_list_serializer(obj): +def get_list_serializer(obj: Any): return force_instance(obj) if is_list_serializer(obj) else get_class(obj)(many=True, context=obj.context) @@ -127,17 +140,17 @@ def is_list_serializer_customized(obj) -> bool: ) -def is_basic_serializer(obj) -> bool: +def is_basic_serializer(obj: Any) -> TypeGuard[_SerializerType]: return is_serializer(obj) and not is_list_serializer(obj) -def is_field(obj): +def is_field(obj: Any) -> TypeGuard[_FieldType]: # make sure obj is a serializer field and nothing else. # guard against serializers because BaseSerializer(Field) return isinstance(force_instance(obj), fields.Field) and not is_serializer(obj) -def is_basic_type(obj, allow_none=True): +def is_basic_type(obj: Any, allow_none=True) -> TypeGuard[_KnownPythonTypes]: if not isinstance(obj, collections.abc.Hashable): return False if not allow_none and (obj is None or obj is OpenApiTypes.NONE): @@ -145,7 +158,7 @@ def is_basic_type(obj, allow_none=True): return obj in get_openapi_type_mapping() or obj in PYTHON_TYPE_MAPPING -def is_patched_serializer(serializer, direction): +def is_patched_serializer(serializer, direction) -> bool: return bool( spectacular_settings.COMPONENT_SPLIT_PATCH and getattr(serializer, 'partial', None) @@ -154,13 +167,13 @@ def is_patched_serializer(serializer, direction): ) -def is_trivial_string_variation(a: str, b: str): +def is_trivial_string_variation(a: str, b: str) -> bool: a = (a or '').strip().lower().replace(' ', '_').replace('-', '_') b = (b or '').strip().lower().replace(' ', '_').replace('-', '_') return a == b -def assert_basic_serializer(serializer): +def assert_basic_serializer(serializer) -> None: assert is_basic_serializer(serializer), ( f'internal assumption violated because we expected a basic serializer here and ' f'instead got a "{serializer}". This may be the result of another app doing ' @@ -208,9 +221,9 @@ def get_view_model(view, emit_warnings=True): ) -def get_doc(obj): +def get_doc(obj) -> str: """ get doc string with fallback on obj's base classes (ignoring DRF documentation). """ - def post_cleanup(doc: str): + def post_cleanup(doc: str) -> str: # also clean up trailing whitespace for each line return '\n'.join(line.rstrip() for line in doc.rstrip().split('\n')) @@ -232,7 +245,7 @@ def safe_index(lst, item): return '' -def get_type_hints(obj): +def get_type_hints(obj) -> Dict[str, Any]: """ unpack wrapped partial object and use actual func object """ if isinstance(obj, functools.partial): obj = obj.func @@ -266,7 +279,7 @@ def build_generic_type(): return {'type': 'object', 'additionalProperties': {}} -def build_basic_type(obj): +def build_basic_type(obj: Union[_KnownPythonTypes, OpenApiTypes]) -> Optional[_SchemaType]: """ resolve either enum or actual type and yield schema template for modification """ @@ -282,7 +295,7 @@ def build_basic_type(obj): return dict(openapi_type_mapping[OpenApiTypes.STR]) -def build_array_type(schema, min_length=None, max_length=None): +def build_array_type(schema: _SchemaType, min_length=None, max_length=None) -> _SchemaType: schema = {'type': 'array', 'items': schema} if min_length is not None: schema['minLength'] = min_length @@ -292,12 +305,12 @@ def build_array_type(schema, min_length=None, max_length=None): def build_object_type( - properties=None, + properties: Optional[_SchemaType] = None, required=None, - description=None, + description: Optional[str] = None, **kwargs -): - schema = {'type': 'object'} +) -> _SchemaType: + schema: _SchemaType = {'type': 'object'} if description: schema['description'] = description.strip() if properties: @@ -310,7 +323,7 @@ def build_object_type( return schema -def build_media_type_object(schema, examples=None, encoding=None): +def build_media_type_object(schema, examples=None, encoding=None) -> _SchemaType: media_type_object = {'schema': schema} if examples: media_type_object['examples'] = examples @@ -319,7 +332,7 @@ def build_media_type_object(schema, examples=None, encoding=None): return media_type_object -def build_examples_list(examples): +def build_examples_list(examples: Sequence[OpenApiExample]) -> _SchemaType: schema = {} for example in examples: normalized_name = inflection.camelize(example.name.replace(' ', '_')) @@ -339,9 +352,9 @@ def build_examples_list(examples): def build_parameter_type( - name, - schema, - location, + name: str, + schema: _SchemaType, + location: _ParameterLocationType, required=False, description=None, enum=None, @@ -353,7 +366,7 @@ def build_parameter_type( allow_blank=True, examples=None, extensions=None, -): +) -> _SchemaType: irrelevant_field_meta = ['readOnly', 'writeOnly'] if location == OpenApiParameter.PATH: irrelevant_field_meta += ['nullable', 'default'] @@ -395,11 +408,11 @@ def build_parameter_type( return schema -def build_choice_field(field): +def build_choice_field(field) -> _SchemaType: choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates if all(isinstance(choice, bool) for choice in choices): - type = 'boolean' + type: Optional[str] = 'boolean' elif all(isinstance(choice, int) for choice in choices): type = 'integer' elif all(isinstance(choice, (int, float, Decimal)) for choice in choices): # `number` includes `integer` @@ -415,7 +428,7 @@ def build_choice_field(field): if field.allow_null and None not in choices: choices.append(None) - schema = { + schema: _SchemaType = { # The value of `enum` keyword MUST be an array and SHOULD be unique. # Ref: https://tools.ietf.org/html/draft-wright-json-schema-validation-00#section-5.20 'enum': choices @@ -464,7 +477,7 @@ def build_bearer_security_scheme_object(header_name, token_prefix, bearer_format } -def build_root_object(paths, components, version): +def build_root_object(paths, components, version) -> _SchemaType: settings = spectacular_settings if settings.VERSION and version: version = f'{settings.VERSION} ({version})' @@ -498,7 +511,7 @@ def build_root_object(paths, components, version): return root -def safe_ref(schema): +def safe_ref(schema: _SchemaType) -> _SchemaType: """ ensure that $ref has its own context and does not remove potential sibling entries when $ref is substituted. also remove useless singular "allOf" . @@ -510,7 +523,7 @@ def safe_ref(schema): return schema -def append_meta(schema, meta): +def append_meta(schema: _SchemaType, meta: _SchemaType) -> _SchemaType: if spectacular_settings.OAS_VERSION.startswith('3.1'): schema_nullable = meta.pop('nullable', None) meta_nullable = schema.pop('nullable', None) @@ -554,9 +567,9 @@ def _follow_field_source(model, path: List[str]): elif isinstance(field_or_property, ReverseOneToOneDescriptor): return field_or_property.related.target_field # o2o reverse elif isinstance(field_or_property, ReverseManyToOneDescriptor): - return field_or_property.rel.target_field # type: ignore # foreign reverse + return field_or_property.rel.target_field # foreign reverse elif isinstance(field_or_property, ForwardManyToOneDescriptor): - return field_or_property.field.target_field # type: ignore # o2o & foreign forward + return field_or_property.field.target_field # o2o & foreign forward else: field = model._meta.get_field(path[0]) if isinstance(field, ForeignObjectRel): @@ -676,20 +689,20 @@ def __bool__(self): return bool(self.name and self.type and self.object) @property - def key(self): + def key(self) -> Tuple[str, str]: return self.name, self.type @property - def ref(self) -> dict: + def ref(self) -> _SchemaType: assert self.__bool__() return {'$ref': f'#/components/{self.type}/{self.name}'} class ComponentRegistry: - def __init__(self): - self._components = {} + def __init__(self) -> None: + self._components: Dict[Tuple[str, str], ResolvedComponent] = {} - def register(self, component: ResolvedComponent): + def register(self, component: ResolvedComponent) -> None: if component in self: warn( f'trying to re-register a {component.type} component with name ' @@ -698,7 +711,7 @@ def register(self, component: ResolvedComponent): ) self._components[component.key] = component - def register_on_missing(self, component: ResolvedComponent): + def register_on_missing(self, component: ResolvedComponent) -> None: if component not in self: self._components[component.key] = component @@ -723,7 +736,7 @@ def __contains__(self, component): ) return True - def __getitem__(self, key): + def __getitem__(self, key) -> ResolvedComponent: if isinstance(key, ResolvedComponent): key = key.key return self._components[key] @@ -733,8 +746,8 @@ def __delitem__(self, key): key = key.key del self._components[key] - def build(self, extra_components) -> dict: - output: DefaultDict[str, dict] = defaultdict(dict) + def build(self, extra_components) -> _SchemaType: + output: DefaultDict[str, _SchemaType] = defaultdict(dict) # build tree from flat registry for component in self._components.values(): output[component.type][component.name] = component.schema @@ -750,7 +763,7 @@ def build(self, extra_components) -> dict: class OpenApiGeneratorExtension(Generic[T], metaclass=ABCMeta): - _registry: List[T] = [] + _registry: List[Type[T]] = [] target_class: Union[None, str, Type[object]] = None match_subclasses = False priority = 0 @@ -785,7 +798,7 @@ def _load_class(cls): cls.target_class = None @classmethod - def _matches(cls, target) -> bool: + def _matches(cls, target: Any) -> bool: if isinstance(cls.target_class, str): cls._load_class() @@ -808,7 +821,7 @@ def get_match(cls, target) -> Optional[T]: return None -def deep_import_string(string): +def deep_import_string(string: str) -> Any: """ augmented import from string, e.g. MODULE.CLASS/OBJECT.ATTRIBUTE """ try: return import_string(string) @@ -864,7 +877,7 @@ def load_enum_name_overrides(): return overrides -def list_hash(lst): +def list_hash(lst: Any) -> str: return hashlib.sha256(json.dumps(list(lst), sort_keys=True, cls=JSONEncoder).encode()).hexdigest()[:16] @@ -965,7 +978,7 @@ def resolve_regex_path_parameter(path_regex, variable): return None -def is_versioning_supported(versioning_class): +def is_versioning_supported(versioning_class) -> bool: return issubclass(versioning_class, ( versioning.URLPathVersioning, versioning.NamespaceVersioning, @@ -973,7 +986,7 @@ def is_versioning_supported(versioning_class): )) -def operation_matches_version(view, requested_version): +def operation_matches_version(view, requested_version) -> bool: try: version, _ = view.determine_version(view.request, **view.kwargs) except exceptions.NotAcceptable: @@ -1039,7 +1052,7 @@ def modify_media_types_for_versioning(view, media_types: List[str]) -> List[str] ] -def analyze_named_regex_pattern(path): +def analyze_named_regex_pattern(path: str) -> Dict[str, str]: """ safely extract named groups and their pattern from given regex pattern """ result = {} stack = 0 @@ -1320,7 +1333,7 @@ def resolve_type_hint(hint): raise UnableToProceedError() -def whitelisted(obj: object, classes: Optional[List[Type[object]]], exact=False): +def whitelisted(obj: object, classes: Optional[List[Type[object]]], exact=False) -> bool: if classes is None: return True if exact: @@ -1338,7 +1351,7 @@ class TmpView(views.APIView): # emulate what Generator would do to setup schema generation. view_callable = TmpView.as_view() - view = view_callable.cls() # type: ignore + view: views.APIView = view_callable.cls() view.request = spectacular_settings.GET_MOCK_REQUEST( method.upper(), path, view, None ) diff --git a/drf_spectacular/utils.py b/drf_spectacular/utils.py index 55fc0389..da009d0a 100644 --- a/drf_spectacular/utils.py +++ b/drf_spectacular/utils.py @@ -11,7 +11,7 @@ from typing_extensions import Final, Literal from rest_framework.fields import Field, empty -from rest_framework.serializers import Serializer +from rest_framework.serializers import ListSerializer, Serializer from rest_framework.settings import api_settings from drf_spectacular.drainage import ( @@ -19,10 +19,12 @@ ) from drf_spectacular.types import OpenApiTypes, _KnownPythonTypes +_ListSerializerType = Union[ListSerializer, Type[ListSerializer]] _SerializerType = Union[Serializer, Type[Serializer]] _FieldType = Union[Field, Type[Field]] _ParameterLocationType = Literal['query', 'path', 'header', 'cookie'] _StrOrPromise = Union[str, Promise] +_SchemaType = Dict[str, Any] Direction = Literal['request', 'response'] @@ -202,7 +204,7 @@ class OpenApiParameter(OpenApiSchemaBase): def __init__( self, name: str, - type: Union[_SerializerType, _KnownPythonTypes, OpenApiTypes, dict] = str, + type: Union[_SerializerType, _KnownPythonTypes, OpenApiTypes, _SchemaType] = str, location: _ParameterLocationType = QUERY, required: bool = False, description: _StrOrPromise = '', @@ -326,7 +328,7 @@ def extend_schema( tags: Optional[Sequence[str]] = None, filters: Optional[bool] = None, exclude: Optional[bool] = None, - operation: Optional[Dict] = None, + operation: Optional[_SchemaType] = None, methods: Optional[Sequence[str]] = None, versions: Optional[Sequence[str]] = None, examples: Optional[Sequence[OpenApiExample]] = None, @@ -542,7 +544,7 @@ def get_external_docs(self): def extend_schema_field( - field: Union[_SerializerType, _FieldType, OpenApiTypes, Dict], + field: Union[_SerializerType, _FieldType, OpenApiTypes, _SchemaType], component_name: Optional[str] = None ) -> Callable[[F], F]: """ diff --git a/requirements/base.txt b/requirements/base.txt index 39a0c1a5..04f5f0e7 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -4,4 +4,4 @@ uritemplate>=2.0.0 PyYAML>=5.1 jsonschema>=2.6.0 inflection>=0.3.1 -typing-extensions; python_version < "3.8" \ No newline at end of file +typing-extensions; python_version < "3.10" \ No newline at end of file diff --git a/requirements/docs.txt b/requirements/docs.txt index 486e4406..ca42c48f 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,2 +1,3 @@ Sphinx>=4.1.0 sphinx_rtd_theme>=0.5.1 +typing-extensions diff --git a/tests/contrib/test_drf_spectacular_sidecar.py b/tests/contrib/test_drf_spectacular_sidecar.py index 1287944f..3baa1c1e 100644 --- a/tests/contrib/test_drf_spectacular_sidecar.py +++ b/tests/contrib/test_drf_spectacular_sidecar.py @@ -34,7 +34,7 @@ def test_sidecar_shortcut_urls_are_resolved(no_warnings): def test_sidecar_package_urls_matching(no_warnings): # poor man's test to make sure the sidecar package contents match with what # collectstatic is going to compile. cannot be tested directly. - import drf_spectacular_sidecar # type: ignore[import] + import drf_spectacular_sidecar # type: ignore[import-not-found] module_root = os.path.dirname(inspect.getfile(drf_spectacular_sidecar)) bundle_path = os.path.join(module_root, BUNDLE_URL) assert os.path.isfile(bundle_path) diff --git a/tests/contrib/test_pydantic.py b/tests/contrib/test_pydantic.py index 61c5e3da..11672fab 100644 --- a/tests/contrib/test_pydantic.py +++ b/tests/contrib/test_pydantic.py @@ -14,7 +14,7 @@ class BaseModel: # type: ignore pass - def dataclass(f): # type: ignore + def dataclass(f): return f diff --git a/tests/test_fields.py b/tests/test_fields.py index 033a09ae..1799ac39 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -93,7 +93,7 @@ class AllFields(models.Model): if DJANGO_VERSION > '3.1': field_nullbool = models.BooleanField(null=True) else: - field_nullbool = models.NullBooleanField() # type: ignore + field_nullbool = models.NullBooleanField() field_time = models.TimeField() field_duration = models.DurationField() field_binary = models.BinaryField() @@ -202,7 +202,7 @@ def get_field_method_object(self, obj) -> dict: ) # type: ignore field_related_slug_queryset = serializers.SlugRelatedField( source='field_foreign', slug_field='url', queryset=Aux.objects.all() - ) # type: ignore + ) field_related_slug_many = serializers.SlugRelatedField( many=True, read_only=True, source='field_m2m', slug_field='url', ) # type: ignore diff --git a/tox.ini b/tox.ini index 7286a549..d5cb2ead 100644 --- a/tox.ini +++ b/tox.ini @@ -97,12 +97,29 @@ use_parentheses = true include_trailing_comma = true [mypy] -python_version = 3.8 +python_version = 3.10 plugins = mypy_django_plugin.main,mypy_drf_plugin.main +warn_unused_configs = True +warn_redundant_casts = True +warn_unused_ignores = True [mypy.plugins.django-stubs] django_settings_module = "tests.settings" +[mypy-drf_spectacular.*] +strict_equality = True +no_implicit_optional = True +disallow_untyped_decorators = True +disallow_subclassing_any = True +no_implicit_reexport = True +;check_untyped_defs = True +;warn_return_any = True +;disallow_incomplete_defs = True +;disallow_any_generics = True +;disallow_untyped_calls = True +;disallow_untyped_defs = True + + [mypy-rest_framework.compat.*] ignore_missing_imports = True