From 5a82479a90e25ad379df264c2b62e53bba2fc92c Mon Sep 17 00:00:00 2001 From: David Salvisberg Date: Thu, 7 Dec 2023 09:47:37 +0100 Subject: [PATCH 1/2] feat: Adds support for SQLAlchemy 2.0 `Mapped` annotations --- README.md | 32 +++ flake8_type_checking/checker.py | 349 ++++++++++++++++++++++++++---- flake8_type_checking/constants.py | 8 +- flake8_type_checking/plugin.py | 14 ++ setup.cfg | 2 + tests/conftest.py | 2 + tests/test_import_visitors.py | 2 + tests/test_name_visitor.py | 2 + tests/test_sqlalchemy.py | 163 ++++++++++++++ 9 files changed, 534 insertions(+), 40 deletions(-) create mode 100644 tests/test_sqlalchemy.py diff --git a/README.md b/README.md index d334554..13d9c34 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,38 @@ Enabling dependency support will also enable FastAPI and Pydantic support. type-checking-fastapi-dependency-support-enabled = true # default false ``` +### SQLAlchemy 2.0+ support + +If you're using SQLAlchemy 2.0+, you can enable support. +This will treat any `Mapped[...]` types as needed at runtime. +It will also specially treat the enclosed type, since it may or may not +need to be available at runtime depending on whether or not the enclosed +type is a model or not, since model can have circular dependencies. + +- **name**: `type-checking-sqlalchemy-enabled` +- **type**: `bool` +```ini +type-checking-sqlalchemy-enabled = true # default false +``` + +### SQLAlchemy 2.0+ support mapped dotted names + +Since it's possible to create subclasses of `sqlalchemy.orm.Mapped` that +define some custom behavior for the mapped attribute, but otherwise still +behave like any other mapped attribute, i.e. the same runtime restrictions +apply it's possible to provide additional dotted names that should be treated +like subclasses of `Mapped`. By default we check for `sqlalchemy.orm.Mapped`, +`sqlalchemy.orm.DynamicMapped` and `sqlalchemy.orm.WriteOnlyMapped`. + +If there's more than one import path for the same `Mapped` subclass, then you +need to specify each of them as a separate dotted name. + +- **name**: `type-checking-sqlalchemy-mapped-dotted-names` +- **type**: `list` +```ini +type-checking-sqlalchemy-mapped-dotted-names = a.MyMapped, a.b.MyMapped # default [] +``` + ### Cattrs support If you're using the plugin in a project which uses `cattrs`, diff --git a/flake8_type_checking/checker.py b/flake8_type_checking/checker.py index 7a01ddd..06fd76d 100644 --- a/flake8_type_checking/checker.py +++ b/flake8_type_checking/checker.py @@ -4,6 +4,7 @@ import fnmatch import os import sys +from abc import ABC, abstractmethod from ast import Index, literal_eval from collections import defaultdict from contextlib import contextmanager, suppress @@ -19,7 +20,6 @@ ATTRIBUTE_PROPERTY, ATTRS_DECORATORS, ATTRS_IMPORTS, - DUNDER_ALL_PROPERTY, NAME_RE, TC001, TC002, @@ -36,6 +36,7 @@ TC201, builtin_names, py38, + sqlalchemy_default_mapped_dotted_names, ) try: @@ -67,6 +68,41 @@ def ast_unparse(node: ast.AST) -> str: ) +class AnnotationVisitor(ABC): + """Simplified node visitor for traversing annotations.""" + + @abstractmethod + def visit_annotation_name(self, node: ast.Name) -> None: + """Visit a name inside an annotation.""" + + @abstractmethod + def visit_annotation_string(self, node: ast.Constant) -> None: + """Visit a string constant inside an annotation.""" + + def visit(self, node: ast.AST) -> None: + """Visit relevant child nodes on an annotation.""" + if node is None: + return + if isinstance(node, ast.BinOp): + if not isinstance(node.op, ast.BitOr): + return + self.visit(node.left) + self.visit(node.right) + elif (py38 and isinstance(node, Index)) or isinstance(node, ast.Attribute): + self.visit(node.value) + elif isinstance(node, ast.Subscript): + self.visit(node.value) + if getattr(node.value, 'id', '') != 'Literal': + self.visit(node.slice) + elif isinstance(node, (ast.Tuple, ast.List)): + for n in node.elts: + self.visit(n) + elif isinstance(node, ast.Constant) and isinstance(node.value, str): + self.visit_annotation_string(node) + elif isinstance(node, ast.Name): + self.visit_annotation_name(node) + + class AttrsMixin: """ Contains necessary logic for cattrs + attrs support. @@ -152,6 +188,7 @@ def generic_visit(self, node: ast.AST) -> None: # noqa: D102 ... def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self.__all___assignments: list[tuple[int, int]] = [] def in___all___declaration(self, node: ast.Constant) -> bool: @@ -201,8 +238,9 @@ def visit_Constant(self, node: ast.Constant) -> ast.Constant: """Map constant as use, if we're inside an __all__ declaration.""" if self.in___all___declaration(node): # for these it doesn't matter where they are declared, the symbol - # just needs to be available in global scope anywhere - setattr(node, DUNDER_ALL_PROPERTY, True) + # just needs to be available in global scope anywhere, we handle + # this by special casing `ast.Constant` when we look for used type + # checking symbols self.uses[node.value].append((node, self.current_scope)) return node @@ -246,6 +284,207 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: self.visit(argument.annotation) +class SQLAlchemyAnnotationVisitor(AnnotationVisitor): + """Adds any names in the annotation to mapped names.""" + + def __init__(self, mapped_names: set[str]) -> None: + self.mapped_names = mapped_names + + def visit_annotation_name(self, node: ast.Name) -> None: + """Add name to mapped names.""" + self.mapped_names.add(node.id) + + def visit_annotation_string(self, node: ast.Constant) -> None: + """Add all the names in the string to mapped names.""" + self.mapped_names.update(NAME_RE.findall(node.value)) + + +class SQLAlchemyMixin: + """ + Contains the necessary logic for SQLAlchemy (https://www.sqlalchemy.org/) support. + + For mapped attributes we don't know whether or not the enclosed type needs to be + available at runtime, since it may or may not be a model. So we treat it like a + runtime use for the purposes of TC001/TC002/TC003 but not anywhere else. + + This supports `sqlalchemy.orm.Mapped`, `sqlalchemy.orm.DynamicMapped` and + `sqlalchemy.orm.WriteOnlyMapped` by default, but can be extended to cover + additional user-defined subclassed of `sqlalchemy.orm.Mapped` using the + setting `type-checking-sqlalchemy-mapped-dotted-names`. + """ + + if TYPE_CHECKING: + sqlalchemy_enabled: bool + sqlalchemy_mapped_dotted_names: set[str] + current_scope: Scope + uses: dict[str, list[tuple[ast.AST, Scope]]] + + def visit(self, node: ast.AST) -> ast.AST: # noqa: D102 + ... + + def in_type_checking_block(self, lineno: int, col_offset: int) -> bool: # noqa: D102 + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + #: Contains a set of all the `Mapped` names that have been imported + self.mapped_aliases: set[str] = set() + + #: Contains a dictionary of aliases to one of our dotted names containing one of the `Mapped` names + self.mapped_module_aliases: dict[str, str] = {} + + #: Contains a set of all names used inside `Mapped[...]` annotations + # These are treated like soft-uses, i.e. we don't know if it will be + # used at runtime or not + self.mapped_names: set[str] = set() + + #: Used for visiting annotations + self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self.mapped_names) + + def visit_Import(self, node: ast.Import) -> None: + """Record `Mapped` module aliases.""" + if not self.sqlalchemy_enabled: + return + + for name in node.names: + # we don't need to map anything + if name.asname is None: + continue + + prefix = name.name + '.' + for dotted_name in self.sqlalchemy_mapped_dotted_names: + if dotted_name.startswith(prefix): + self.mapped_module_aliases[name.asname] = name.name + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + """Record `Mapped` aliases.""" + if not self.sqlalchemy_enabled: + return + + # we don't try to deal with relative imports + if node.module is None or node.level != 0: + return + + prefix = node.module + '.' + for dotted_name in self.sqlalchemy_mapped_dotted_names: + if not dotted_name.startswith(prefix): + continue + + suffix = dotted_name[len(prefix) :] + if '.' in suffix: + # we may need to record a mapped module + for name in node.names: + if suffix.startswith(name.name + '.'): + self.mapped_module_aliases[name.asname or name.name] = node.module + elif name.name == '*': + # in this case we can assume that the next segment of the dotted + # name has been imported + self.mapped_module_aliases[suffix.split('.', 1)[0]] = node.module + continue + + # we may need to record a mapped name + for name in node.names: + if suffix == name.name: + self.mapped_aliases.add(name.asname or name.name) + elif name.name == '*': + # in this case we can assume it has been imported + self.mapped_aliases.add(suffix) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + """Remove all annotations assigments.""" + if ( + self.sqlalchemy_enabled + # We only need to special case runtime use of `Mapped` + and not self.in_type_checking_block(node.lineno, node.col_offset) + ): + self.handle_sqlalchemy_annotation(node.annotation) + + def handle_sqlalchemy_annotation(self, node: ast.AST) -> None: + """ + Record all the mapped names inside the annotations. + + If the annotation is an `ast.Constant` that starts with one of our + `Mapped` names, then we will record a runtime use of that symbol, + since we know `Mapped` always needs to resolve. + """ + if isinstance(node, ast.Constant): + # we only need to handle annotations like `"Mapped[...]"` + if not isinstance(node.value, str) or '[' not in node.value: + return + + annotation = node.value.strip() + if not annotation.endswith(']'): + return + + # if we ever do more sophisticated parsing of text annotations + # then we would want to strip the trailing `]` from inner, but + # with our simple parsing we don't care + mapped_name, inner = annotation.split('[', 1) + if mapped_name in self.mapped_aliases: + # record a use for the name + self.uses[mapped_name].append((node, self.current_scope)) + + elif mapped_name in self.sqlalchemy_mapped_dotted_names: + # record a use for the left-most part of the dotted name + self.uses[mapped_name.split('.', 1)[0]].append((node, self.current_scope)) + + elif '.' in annotation: + # try to resolve to a mapped module alias + aliased, mapped_name = mapped_name.split('.', 1) + module = self.mapped_module_aliases.get(aliased) + if module is None or f'{module}.{mapped_name}' not in self.sqlalchemy_mapped_dotted_names: + return + + # record a use for the alias + self.uses[aliased].append((node, self.current_scope)) + + # add all names contained in the inner part of the annotation + # since this is not as strict as an actual runtime use, we don't + # care if we record too much here + self.mapped_names.update(NAME_RE.findall(inner)) + return + + # we only need to handle annotations like `Mapped[...]` + if not isinstance(node, ast.Subscript): + return + + # simple case only needs to check mapped_aliases + if isinstance(node.value, ast.Name): + if node.value.id not in self.mapped_aliases: + return + + # record a use for the name + self.uses[node.value.id].append((node.value, self.current_scope)) + + # complex case for dotted names + elif isinstance(node.value, ast.Attribute): + dotted_name = node.value.attr + before_dot = node.value.value + while isinstance(before_dot, ast.Attribute): + dotted_name = f'{before_dot.attr}.{dotted_name}' + before_dot = before_dot.value + # there should be no subscripts between the attributes + if not isinstance(before_dot, ast.Name): + return + + # map the module if it's mapped otherwise use it as is + module = self.mapped_module_aliases.get(before_dot.id, before_dot.id) + dotted_name = f'{module}.{dotted_name}' + if dotted_name not in self.sqlalchemy_mapped_dotted_names: + return + + # record a use for the left-most node in the attribute access chain + self.uses[before_dot.id].append((before_dot, self.current_scope)) + + # any other case is invalid, such as `Foo[...][...]` + else: + return + + # visit the wrapped annotations to update the mapped names + self.sqlalchemy_annotation_visitor.visit(node.slice) + + class InjectorMixin: """ Contains the necessary logic for injector (https://github.com/python-injector/injector) support. @@ -569,7 +808,46 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only: return parent.lookup(symbol_name, use, runtime_only) -class ImportVisitor(DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, ast.NodeVisitor): +class ImportAnnotationVisitor(AnnotationVisitor): + """Map all annotations on an AST node.""" + + def __init__(self) -> None: + #: All type annotations in the file, without quotes around them + self.unwrapped_annotations: list[UnwrappedAnnotation] = [] + + #: All type annotations in the file, with quotes around them + self.wrapped_annotations: list[WrappedAnnotation] = [] + + def visit( + self, node: ast.AST, scope: Scope | None = None, type: Literal['annotation', 'alias', 'new-alias'] | None = None + ) -> None: + """Visit the node with the given scope and annotation type.""" + if scope is not None: + self.scope = scope + if type is not None: + self.type = type + super().visit(node) + + def visit_annotation_name(self, node: ast.Name) -> None: + """Register unwrapped annotation.""" + setattr(node, ANNOTATION_PROPERTY, True) + self.unwrapped_annotations.append( + UnwrappedAnnotation(node.lineno, node.col_offset, node.id, self.scope, self.type) + ) + + def visit_annotation_string(self, node: ast.Constant) -> None: + """Register wrapped annotation.""" + setattr(node, ANNOTATION_PROPERTY, True) + self.wrapped_annotations.append( + WrappedAnnotation( + node.lineno, node.col_offset, node.value, set(NAME_RE.findall(node.value)), self.scope, self.type + ) + ) + + +class ImportVisitor( + DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, SQLAlchemyMixin, ast.NodeVisitor +): """Map all imports outside of type-checking blocks.""" #: The currently active scope @@ -581,6 +859,8 @@ def __init__( pydantic_enabled: bool, fastapi_enabled: bool, fastapi_dependency_support_enabled: bool, + sqlalchemy_enabled: bool, + sqlalchemy_mapped_dotted_names: list[str], injector_enabled: bool, cattrs_enabled: bool, pydantic_enabled_baseclass_passlist: list[str], @@ -593,6 +873,10 @@ def __init__( self.fastapi_enabled = fastapi_enabled self.fastapi_dependency_support_enabled = fastapi_dependency_support_enabled self.cattrs_enabled = cattrs_enabled + self.sqlalchemy_enabled = sqlalchemy_enabled + self.sqlalchemy_mapped_dotted_names = sqlalchemy_default_mapped_dotted_names | set( + sqlalchemy_mapped_dotted_names + ) self.injector_enabled = injector_enabled self.pydantic_enabled_baseclass_passlist = pydantic_enabled_baseclass_passlist self.pydantic_validate_arguments_import_name = None @@ -619,11 +903,8 @@ def __init__( #: List of all names and ids, except type declarations self.uses: dict[str, list[tuple[ast.AST, Scope]]] = defaultdict(list) - #: All type annotations in the file, without quotes around them - self.unwrapped_annotations: list[UnwrappedAnnotation] = [] - - #: All type annotations in the file, with quotes around them - self.wrapped_annotations: list[WrappedAnnotation] = [] + #: Handles logic of visiting annotation nodes + self.annotation_visitor = ImportAnnotationVisitor() #: Whether there is a `from __futures__ import annotations` is present in the file self.futures_annotation: Optional[bool] = None @@ -667,6 +948,16 @@ def change_scope(self, scope: Scope) -> Iterator[None]: yield self.current_scope = old_scope + @property + def unwrapped_annotations(self) -> list[UnwrappedAnnotation]: + """All type annotations in the file, without quotes around them.""" + return self.annotation_visitor.unwrapped_annotations + + @property + def wrapped_annotations(self) -> list[WrappedAnnotation]: + """All type annotations in the file, with quotes around them.""" + return self.annotation_visitor.wrapped_annotations + @property def typing_module_name(self) -> str: """ @@ -939,10 +1230,12 @@ def add_import(self, node: Import) -> None: # noqa: C901 def visit_Import(self, node: ast.Import) -> None: """Append objects to our import map.""" + super().visit_Import(node) self.add_import(node) def visit_ImportFrom(self, node: ast.ImportFrom) -> None: """Append objects to our import map.""" + super().visit_ImportFrom(node) self.add_import(node) def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: @@ -1022,34 +1315,7 @@ def add_annotation( self, node: ast.AST, scope: Scope, type: Literal['annotation', 'alias', 'new-alias'] = 'annotation' ) -> None: """Map all annotations on an AST node.""" - if node is None: - return - if isinstance(node, ast.BinOp): - if not isinstance(node.op, ast.BitOr): - return - self.add_annotation(node.left, scope, type) - self.add_annotation(node.right, scope, type) - elif (py38 and isinstance(node, Index)) or isinstance(node, ast.Attribute): - self.add_annotation(node.value, scope, type) - elif isinstance(node, ast.Subscript): - self.add_annotation(node.value, scope, type) - if getattr(node.value, 'id', '') != 'Literal': - self.add_annotation(node.slice, scope, type) - elif isinstance(node, (ast.Tuple, ast.List)): - for n in node.elts: - self.add_annotation(n, scope, type) - elif isinstance(node, ast.Constant) and isinstance(node.value, str): - # Register annotation value - setattr(node, ANNOTATION_PROPERTY, True) - self.wrapped_annotations.append( - WrappedAnnotation( - node.lineno, node.col_offset, node.value, set(NAME_RE.findall(node.value)), scope, type - ) - ) - elif isinstance(node, ast.Name): - # Register annotation value - setattr(node, ANNOTATION_PROPERTY, True) - self.unwrapped_annotations.append(UnwrappedAnnotation(node.lineno, node.col_offset, node.id, scope, type)) + self.annotation_visitor.visit(node, scope, type) @staticmethod def set_child_node_attribute(node: Any, attr: str, val: Any) -> Any: @@ -1078,6 +1344,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None: an annotation as well, but we have to keep in mind that the RHS will not automatically become a ForwardRef with a future import, like a true annotation would. """ + super().visit_AnnAssign(node) self.add_annotation(node.annotation, self.current_scope) if node.value is None: @@ -1494,6 +1761,8 @@ def __init__(self, node: ast.Module, options: Optional[Namespace]) -> None: exempt_modules = getattr(options, 'type_checking_exempt_modules', []) pydantic_enabled = getattr(options, 'type_checking_pydantic_enabled', False) pydantic_enabled_baseclass_passlist = getattr(options, 'type_checking_pydantic_enabled_baseclass_passlist', []) + sqlalchemy_enabled = getattr(options, 'type_checking_sqlalchemy_enabled', False) + sqlalchemy_mapped_dotted_names = getattr(options, 'type_checking_sqlalchemy_mapped_dotted_names', []) fastapi_enabled = getattr(options, 'type_checking_fastapi_enabled', False) fastapi_dependency_support_enabled = getattr(options, 'type_checking_fastapi_dependency_support_enabled', False) cattrs_enabled = getattr(options, 'type_checking_cattrs_enabled', False) @@ -1514,6 +1783,8 @@ def __init__(self, node: ast.Module, options: Optional[Namespace]) -> None: fastapi_enabled=fastapi_enabled, cattrs_enabled=cattrs_enabled, exempt_modules=exempt_modules, + sqlalchemy_enabled=sqlalchemy_enabled, + sqlalchemy_mapped_dotted_names=sqlalchemy_mapped_dotted_names, fastapi_dependency_support_enabled=fastapi_dependency_support_enabled, pydantic_enabled_baseclass_passlist=pydantic_enabled_baseclass_passlist, injector_enabled=injector_enabled, @@ -1545,7 +1816,7 @@ def unused_imports(self) -> Flake8Generator: Classified.BUILTIN: (self.visitor.built_in_imports, TC003), } - unused_imports = set(self.visitor.imports) - self.visitor.names + unused_imports = set(self.visitor.imports) - self.visitor.names - self.visitor.mapped_names used_imports = set(self.visitor.imports) - unused_imports already_imported_modules = [self.visitor.imports[name].module for name in used_imports] annotation_names = [n for i in self.visitor.wrapped_annotations for n in i.names] + [ @@ -1586,7 +1857,7 @@ def used_type_checking_symbols(self) -> Flake8Generator: # only imports and definitions can be moved around continue - if getattr(use, DUNDER_ALL_PROPERTY, False): + if isinstance(use, ast.Constant): # this is actually a quoted name, so it should exist # as long as it's in the scope at all, we don't need # to take the position into account diff --git a/flake8_type_checking/constants.py b/flake8_type_checking/constants.py index 8436af5..5de9c22 100644 --- a/flake8_type_checking/constants.py +++ b/flake8_type_checking/constants.py @@ -6,7 +6,6 @@ ATTRIBUTE_PROPERTY = '_flake8-type-checking__parent' ANNOTATION_PROPERTY = '_flake8-type-checking__is_annotation' -DUNDER_ALL_PROPERTY = '_flake8-type-checking__in__all__' NAME_RE = re.compile(r'(? None: # pragma: no cover default=[], help='Names of base classes to not treat as pydantic models. For example `NamedTuple` or `TypedDict`.', ) + option_manager.add_option( + '--type-checking-sqlalchemy-enabled', + action='store_true', + parse_from_config=True, + default=False, + help='Prevent flagging of annotations on mapped attributes.', + ) + option_manager.add_option( + '--type-checking-sqlalchemy-mapped-dotted-names', + comma_separated_list=True, + parse_from_config=True, + default=[], + help='Dotted names of additional Mapped subclasses. For example `a.MyMapped, a.b.MyMapped`', + ) option_manager.add_option( '--type-checking-fastapi-enabled', action='store_true', diff --git a/setup.cfg b/setup.cfg index 5f8a1ac..bc210df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,10 +12,12 @@ ignore= # W503 line break before binary operator W503, # New black update seem to be incompatible with the next 3 + # E203 whitespace before ':' # E231 missing whitespace after ':' # E241 multiple spaces after ':' # E272 multiple spaces before keyword # E271 multiple spaces after keyword + E203, E231, E241, E272, diff --git a/tests/conftest.py b/tests/conftest.py index a941f98..4b43422 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,6 +41,8 @@ def _get_error(example: str, *, error_code_filter: Optional[str] = None, **kwarg mock_options.type_checking_fastapi_enabled = False mock_options.type_checking_fastapi_dependency_support_enabled = False mock_options.type_checking_pydantic_enabled_baseclass_passlist = [] + mock_options.type_checking_sqlalchemy_enabled = False + mock_options.type_checking_sqlalchemy_mapped_dotted_names = [] mock_options.type_checking_injector_enabled = False mock_options.type_checking_strict = False # kwarg overrides diff --git a/tests/test_import_visitors.py b/tests/test_import_visitors.py index 781f4d5..5ee6792 100644 --- a/tests/test_import_visitors.py +++ b/tests/test_import_visitors.py @@ -19,6 +19,8 @@ def _visit(example: str) -> ImportVisitor: fastapi_enabled=False, fastapi_dependency_support_enabled=False, cattrs_enabled=False, + sqlalchemy_enabled=False, + sqlalchemy_mapped_dotted_names=[], injector_enabled=False, pydantic_enabled_baseclass_passlist=[], ) diff --git a/tests/test_name_visitor.py b/tests/test_name_visitor.py index ee79fd8..c282108 100644 --- a/tests/test_name_visitor.py +++ b/tests/test_name_visitor.py @@ -19,6 +19,8 @@ def _get_names(example: str) -> Set[str]: fastapi_enabled=False, fastapi_dependency_support_enabled=False, cattrs_enabled=False, + sqlalchemy_enabled=False, + sqlalchemy_mapped_dotted_names=[], injector_enabled=False, pydantic_enabled_baseclass_passlist=[], ) diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py new file mode 100644 index 0000000..b8e1079 --- /dev/null +++ b/tests/test_sqlalchemy.py @@ -0,0 +1,163 @@ +""" +This file tests SQLAlchemy support. + +See https://github.com/snok/flake8-type-checking/issues/178 +for discussion on the implementation. +""" + +import textwrap + +import pytest + +from flake8_type_checking.constants import TC002, TC004 +from tests.conftest import _get_error + + +@pytest.mark.parametrize( + ('enabled', 'expected'), + [ + (True, set()), + (False, {'2:0 ' + TC002.format(module='foo.Bar'), '3:0 ' + TC002.format(module='sqlalchemy.orm.Mapped')}), + ], +) +def test_simple_mapped_use(enabled, expected): + """ + Mapped itself must be available at runtime and the inner type may or + may not need to be available at runtime. + """ + example = textwrap.dedent(''' + from foo import Bar + from sqlalchemy.orm import Mapped + + class User: + x: Mapped[Bar] + ''') + assert _get_error(example, error_code_filter='TC002', type_checking_sqlalchemy_enabled=enabled) == expected + + +@pytest.mark.parametrize( + ('name', 'expected'), + [ + ('Mapped', set()), + ('DynamicMapped', set()), + ('WriteOnlyMapped', set()), + ( + 'NotMapped', + {'2:0 ' + TC002.format(module='foo.Bar'), '3:0 ' + TC002.format(module='sqlalchemy.orm.NotMapped')}, + ), + ], +) +def test_default_mapped_names(name, expected): + """Check the three default names and a bogus name.""" + example = textwrap.dedent(f''' + from foo import Bar + from sqlalchemy.orm import {name} + + class User: + x: {name}[Bar] + ''') + assert _get_error(example, error_code_filter='TC002', type_checking_sqlalchemy_enabled=True) == expected + + +def test_mapped_with_circular_forward_reference(): + """ + Mapped must still be available at runtime even with forward references + to a different model. + """ + example = textwrap.dedent(''' + from sqlalchemy.orm import Mapped + if TYPE_CHECKING: + from .address import Address + + class User: + address: Mapped['Address'] + ''') + assert _get_error(example, error_code_filter='TC002', type_checking_sqlalchemy_enabled=True) == set() + + +def test_mapped_use_without_runtime_import(): + """ + Mapped must be available at runtime, so even if it is inside a wrapped annotation + we should raise a TC004 for Mapped but not for Bar + """ + example = textwrap.dedent(''' + if TYPE_CHECKING: + from foo import Bar + from sqlalchemy.orm import Mapped + + class User: + created: 'Mapped[Bar]' + ''') + assert _get_error(example, error_code_filter='TC004', type_checking_sqlalchemy_enabled=True) == { + '4:0 ' + TC004.format(module='Mapped') + } + + +def test_custom_mapped_dotted_names_unwrapped(): + """ + Check a couple of custom dotted names and a bogus one. This also tests the + various styles of imports + """ + example = textwrap.dedent(''' + import a + import a.b as ab + from a import b + from a import MyMapped + from a.b import MyMapped as B + from a import Bogus + from foo import Bar + + class User: + t: MyMapped[Bar] + u: B[Bar] + v: Bogus[Bar] + w: a.MyMapped[Bar] + x: b.MyMapped[Bar] + y: a.b.MyMapped[Bar] + z: ab.MyMapped[Bar] + ''') + assert _get_error( + example, + error_code_filter='TC002', + type_checking_strict=True, # ignore overlapping imports for this test + type_checking_sqlalchemy_enabled=True, + type_checking_sqlalchemy_mapped_dotted_names=['a.MyMapped', 'a.b.MyMapped'], + ) == {'7:0 ' + TC002.format(module='a.Bogus')} + + +def test_custom_mapped_dotted_names_wrapped(): + """ + Same as the unwrapped test but with wrapped annotations. This should generate + a bunch of TC004 errors for the uses of mapped that should be available at runtime. + """ + example = textwrap.dedent(''' + if TYPE_CHECKING: + import a + import a.b as ab + from a import b + from a import MyMapped + from a.b import MyMapped as B + from a import Bogus + from foo import Bar + + class User: + t: 'MyMapped[Bar]' + u: 'B[Bar]' + v: 'Bogus[Bar]' + w: 'a.MyMapped[Bar]' + x: 'b.MyMapped[Bar]' + y: 'a.b.MyMapped[Bar]' + z: 'ab.MyMapped[Bar]' + ''') + assert _get_error( + example, + error_code_filter='TC004', + type_checking_sqlalchemy_enabled=True, + type_checking_sqlalchemy_mapped_dotted_names=['a.MyMapped', 'a.b.MyMapped'], + ) == { + '3:0 ' + TC004.format(module='a'), + '4:0 ' + TC004.format(module='ab'), + '5:0 ' + TC004.format(module='b'), + '6:0 ' + TC004.format(module='MyMapped'), + '7:0 ' + TC004.format(module='B'), + } From d7a40c7a4b84bf8e604d8fd610feeb057e23cc14 Mon Sep 17 00:00:00 2001 From: David Salvisberg Date: Thu, 7 Dec 2023 13:05:11 +0100 Subject: [PATCH 2/2] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sondre Lillebø Gundersen --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 13d9c34..9661932 100644 --- a/README.md +++ b/README.md @@ -232,7 +232,7 @@ If you're using SQLAlchemy 2.0+, you can enable support. This will treat any `Mapped[...]` types as needed at runtime. It will also specially treat the enclosed type, since it may or may not need to be available at runtime depending on whether or not the enclosed -type is a model or not, since model can have circular dependencies. +type is a model or not, since models can have circular dependencies. - **name**: `type-checking-sqlalchemy-enabled` - **type**: `bool`