diff --git a/README.md b/README.md index d334554..9661932 100644 --- a/README.md +++ b/README.md @@ -226,6 +226,38 @@ Enabling dependency support will also enable FastAPI and Pydantic support. type-checking-fastapi-dependency-support-enabled = true # default false ``` +### SQLAlchemy 2.0+ support + +If you're using SQLAlchemy 2.0+, you can enable support. +This will treat any `Mapped[...]` types as needed at runtime. +It will also specially treat the enclosed type, since it may or may not +need to be available at runtime depending on whether or not the enclosed +type is a model or not, since models can have circular dependencies. + +- **name**: `type-checking-sqlalchemy-enabled` +- **type**: `bool` +```ini +type-checking-sqlalchemy-enabled = true # default false +``` + +### SQLAlchemy 2.0+ support mapped dotted names + +Since it's possible to create subclasses of `sqlalchemy.orm.Mapped` that +define some custom behavior for the mapped attribute, but otherwise still +behave like any other mapped attribute, i.e. the same runtime restrictions +apply it's possible to provide additional dotted names that should be treated +like subclasses of `Mapped`. By default we check for `sqlalchemy.orm.Mapped`, +`sqlalchemy.orm.DynamicMapped` and `sqlalchemy.orm.WriteOnlyMapped`. + +If there's more than one import path for the same `Mapped` subclass, then you +need to specify each of them as a separate dotted name. + +- **name**: `type-checking-sqlalchemy-mapped-dotted-names` +- **type**: `list` +```ini +type-checking-sqlalchemy-mapped-dotted-names = a.MyMapped, a.b.MyMapped # default [] +``` + ### Cattrs support If you're using the plugin in a project which uses `cattrs`, diff --git a/flake8_type_checking/checker.py b/flake8_type_checking/checker.py index 7a01ddd..06fd76d 100644 --- a/flake8_type_checking/checker.py +++ b/flake8_type_checking/checker.py @@ -4,6 +4,7 @@ import fnmatch import os import sys +from abc import ABC, abstractmethod from ast import Index, literal_eval from collections import defaultdict from contextlib import contextmanager, suppress @@ -19,7 +20,6 @@ ATTRIBUTE_PROPERTY, ATTRS_DECORATORS, ATTRS_IMPORTS, - DUNDER_ALL_PROPERTY, NAME_RE, TC001, TC002, @@ -36,6 +36,7 @@ TC201, builtin_names, py38, + sqlalchemy_default_mapped_dotted_names, ) try: @@ -67,6 +68,41 @@ def ast_unparse(node: ast.AST) -> str: ) +class AnnotationVisitor(ABC): + """Simplified node visitor for traversing annotations.""" + + @abstractmethod + def visit_annotation_name(self, node: ast.Name) -> None: + """Visit a name inside an annotation.""" + + @abstractmethod + def visit_annotation_string(self, node: ast.Constant) -> None: + """Visit a string constant inside an annotation.""" + + def visit(self, node: ast.AST) -> None: + """Visit relevant child nodes on an annotation.""" + if node is None: + return + if isinstance(node, ast.BinOp): + if not isinstance(node.op, ast.BitOr): + return + self.visit(node.left) + self.visit(node.right) + elif (py38 and isinstance(node, Index)) or isinstance(node, ast.Attribute): + self.visit(node.value) + elif isinstance(node, ast.Subscript): + self.visit(node.value) + if getattr(node.value, 'id', '') != 'Literal': + self.visit(node.slice) + elif isinstance(node, (ast.Tuple, ast.List)): + for n in node.elts: + self.visit(n) + elif isinstance(node, ast.Constant) and isinstance(node.value, str): + self.visit_annotation_string(node) + elif isinstance(node, ast.Name): + self.visit_annotation_name(node) + + class AttrsMixin: """ Contains necessary logic for cattrs + attrs support. @@ -152,6 +188,7 @@ def generic_visit(self, node: ast.AST) -> None: # noqa: D102 ... def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self.__all___assignments: list[tuple[int, int]] = [] def in___all___declaration(self, node: ast.Constant) -> bool: @@ -201,8 +238,9 @@ def visit_Constant(self, node: ast.Constant) -> ast.Constant: """Map constant as use, if we're inside an __all__ declaration.""" if self.in___all___declaration(node): # for these it doesn't matter where they are declared, the symbol - # just needs to be available in global scope anywhere - setattr(node, DUNDER_ALL_PROPERTY, True) + # just needs to be available in global scope anywhere, we handle + # this by special casing `ast.Constant` when we look for used type + # checking symbols self.uses[node.value].append((node, self.current_scope)) return node @@ -246,6 +284,207 @@ def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None: self.visit(argument.annotation) +class SQLAlchemyAnnotationVisitor(AnnotationVisitor): + """Adds any names in the annotation to mapped names.""" + + def __init__(self, mapped_names: set[str]) -> None: + self.mapped_names = mapped_names + + def visit_annotation_name(self, node: ast.Name) -> None: + """Add name to mapped names.""" + self.mapped_names.add(node.id) + + def visit_annotation_string(self, node: ast.Constant) -> None: + """Add all the names in the string to mapped names.""" + self.mapped_names.update(NAME_RE.findall(node.value)) + + +class SQLAlchemyMixin: + """ + Contains the necessary logic for SQLAlchemy (https://www.sqlalchemy.org/) support. + + For mapped attributes we don't know whether or not the enclosed type needs to be + available at runtime, since it may or may not be a model. So we treat it like a + runtime use for the purposes of TC001/TC002/TC003 but not anywhere else. + + This supports `sqlalchemy.orm.Mapped`, `sqlalchemy.orm.DynamicMapped` and + `sqlalchemy.orm.WriteOnlyMapped` by default, but can be extended to cover + additional user-defined subclassed of `sqlalchemy.orm.Mapped` using the + setting `type-checking-sqlalchemy-mapped-dotted-names`. + """ + + if TYPE_CHECKING: + sqlalchemy_enabled: bool + sqlalchemy_mapped_dotted_names: set[str] + current_scope: Scope + uses: dict[str, list[tuple[ast.AST, Scope]]] + + def visit(self, node: ast.AST) -> ast.AST: # noqa: D102 + ... + + def in_type_checking_block(self, lineno: int, col_offset: int) -> bool: # noqa: D102 + ... + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + #: Contains a set of all the `Mapped` names that have been imported + self.mapped_aliases: set[str] = set() + + #: Contains a dictionary of aliases to one of our dotted names containing one of the `Mapped` names + self.mapped_module_aliases: dict[str, str] = {} + + #: Contains a set of all names used inside `Mapped[...]` annotations + # These are treated like soft-uses, i.e. we don't know if it will be + # used at runtime or not + self.mapped_names: set[str] = set() + + #: Used for visiting annotations + self.sqlalchemy_annotation_visitor = SQLAlchemyAnnotationVisitor(self.mapped_names) + + def visit_Import(self, node: ast.Import) -> None: + """Record `Mapped` module aliases.""" + if not self.sqlalchemy_enabled: + return + + for name in node.names: + # we don't need to map anything + if name.asname is None: + continue + + prefix = name.name + '.' + for dotted_name in self.sqlalchemy_mapped_dotted_names: + if dotted_name.startswith(prefix): + self.mapped_module_aliases[name.asname] = name.name + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + """Record `Mapped` aliases.""" + if not self.sqlalchemy_enabled: + return + + # we don't try to deal with relative imports + if node.module is None or node.level != 0: + return + + prefix = node.module + '.' + for dotted_name in self.sqlalchemy_mapped_dotted_names: + if not dotted_name.startswith(prefix): + continue + + suffix = dotted_name[len(prefix) :] + if '.' in suffix: + # we may need to record a mapped module + for name in node.names: + if suffix.startswith(name.name + '.'): + self.mapped_module_aliases[name.asname or name.name] = node.module + elif name.name == '*': + # in this case we can assume that the next segment of the dotted + # name has been imported + self.mapped_module_aliases[suffix.split('.', 1)[0]] = node.module + continue + + # we may need to record a mapped name + for name in node.names: + if suffix == name.name: + self.mapped_aliases.add(name.asname or name.name) + elif name.name == '*': + # in this case we can assume it has been imported + self.mapped_aliases.add(suffix) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + """Remove all annotations assigments.""" + if ( + self.sqlalchemy_enabled + # We only need to special case runtime use of `Mapped` + and not self.in_type_checking_block(node.lineno, node.col_offset) + ): + self.handle_sqlalchemy_annotation(node.annotation) + + def handle_sqlalchemy_annotation(self, node: ast.AST) -> None: + """ + Record all the mapped names inside the annotations. + + If the annotation is an `ast.Constant` that starts with one of our + `Mapped` names, then we will record a runtime use of that symbol, + since we know `Mapped` always needs to resolve. + """ + if isinstance(node, ast.Constant): + # we only need to handle annotations like `"Mapped[...]"` + if not isinstance(node.value, str) or '[' not in node.value: + return + + annotation = node.value.strip() + if not annotation.endswith(']'): + return + + # if we ever do more sophisticated parsing of text annotations + # then we would want to strip the trailing `]` from inner, but + # with our simple parsing we don't care + mapped_name, inner = annotation.split('[', 1) + if mapped_name in self.mapped_aliases: + # record a use for the name + self.uses[mapped_name].append((node, self.current_scope)) + + elif mapped_name in self.sqlalchemy_mapped_dotted_names: + # record a use for the left-most part of the dotted name + self.uses[mapped_name.split('.', 1)[0]].append((node, self.current_scope)) + + elif '.' in annotation: + # try to resolve to a mapped module alias + aliased, mapped_name = mapped_name.split('.', 1) + module = self.mapped_module_aliases.get(aliased) + if module is None or f'{module}.{mapped_name}' not in self.sqlalchemy_mapped_dotted_names: + return + + # record a use for the alias + self.uses[aliased].append((node, self.current_scope)) + + # add all names contained in the inner part of the annotation + # since this is not as strict as an actual runtime use, we don't + # care if we record too much here + self.mapped_names.update(NAME_RE.findall(inner)) + return + + # we only need to handle annotations like `Mapped[...]` + if not isinstance(node, ast.Subscript): + return + + # simple case only needs to check mapped_aliases + if isinstance(node.value, ast.Name): + if node.value.id not in self.mapped_aliases: + return + + # record a use for the name + self.uses[node.value.id].append((node.value, self.current_scope)) + + # complex case for dotted names + elif isinstance(node.value, ast.Attribute): + dotted_name = node.value.attr + before_dot = node.value.value + while isinstance(before_dot, ast.Attribute): + dotted_name = f'{before_dot.attr}.{dotted_name}' + before_dot = before_dot.value + # there should be no subscripts between the attributes + if not isinstance(before_dot, ast.Name): + return + + # map the module if it's mapped otherwise use it as is + module = self.mapped_module_aliases.get(before_dot.id, before_dot.id) + dotted_name = f'{module}.{dotted_name}' + if dotted_name not in self.sqlalchemy_mapped_dotted_names: + return + + # record a use for the left-most node in the attribute access chain + self.uses[before_dot.id].append((before_dot, self.current_scope)) + + # any other case is invalid, such as `Foo[...][...]` + else: + return + + # visit the wrapped annotations to update the mapped names + self.sqlalchemy_annotation_visitor.visit(node.slice) + + class InjectorMixin: """ Contains the necessary logic for injector (https://github.com/python-injector/injector) support. @@ -569,7 +808,46 @@ def lookup(self, symbol_name: str, use: HasPosition | None = None, runtime_only: return parent.lookup(symbol_name, use, runtime_only) -class ImportVisitor(DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, ast.NodeVisitor): +class ImportAnnotationVisitor(AnnotationVisitor): + """Map all annotations on an AST node.""" + + def __init__(self) -> None: + #: All type annotations in the file, without quotes around them + self.unwrapped_annotations: list[UnwrappedAnnotation] = [] + + #: All type annotations in the file, with quotes around them + self.wrapped_annotations: list[WrappedAnnotation] = [] + + def visit( + self, node: ast.AST, scope: Scope | None = None, type: Literal['annotation', 'alias', 'new-alias'] | None = None + ) -> None: + """Visit the node with the given scope and annotation type.""" + if scope is not None: + self.scope = scope + if type is not None: + self.type = type + super().visit(node) + + def visit_annotation_name(self, node: ast.Name) -> None: + """Register unwrapped annotation.""" + setattr(node, ANNOTATION_PROPERTY, True) + self.unwrapped_annotations.append( + UnwrappedAnnotation(node.lineno, node.col_offset, node.id, self.scope, self.type) + ) + + def visit_annotation_string(self, node: ast.Constant) -> None: + """Register wrapped annotation.""" + setattr(node, ANNOTATION_PROPERTY, True) + self.wrapped_annotations.append( + WrappedAnnotation( + node.lineno, node.col_offset, node.value, set(NAME_RE.findall(node.value)), self.scope, self.type + ) + ) + + +class ImportVisitor( + DunderAllMixin, AttrsMixin, InjectorMixin, FastAPIMixin, PydanticMixin, SQLAlchemyMixin, ast.NodeVisitor +): """Map all imports outside of type-checking blocks.""" #: The currently active scope @@ -581,6 +859,8 @@ def __init__( pydantic_enabled: bool, fastapi_enabled: bool, fastapi_dependency_support_enabled: bool, + sqlalchemy_enabled: bool, + sqlalchemy_mapped_dotted_names: list[str], injector_enabled: bool, cattrs_enabled: bool, pydantic_enabled_baseclass_passlist: list[str], @@ -593,6 +873,10 @@ def __init__( self.fastapi_enabled = fastapi_enabled self.fastapi_dependency_support_enabled = fastapi_dependency_support_enabled self.cattrs_enabled = cattrs_enabled + self.sqlalchemy_enabled = sqlalchemy_enabled + self.sqlalchemy_mapped_dotted_names = sqlalchemy_default_mapped_dotted_names | set( + sqlalchemy_mapped_dotted_names + ) self.injector_enabled = injector_enabled self.pydantic_enabled_baseclass_passlist = pydantic_enabled_baseclass_passlist self.pydantic_validate_arguments_import_name = None @@ -619,11 +903,8 @@ def __init__( #: List of all names and ids, except type declarations self.uses: dict[str, list[tuple[ast.AST, Scope]]] = defaultdict(list) - #: All type annotations in the file, without quotes around them - self.unwrapped_annotations: list[UnwrappedAnnotation] = [] - - #: All type annotations in the file, with quotes around them - self.wrapped_annotations: list[WrappedAnnotation] = [] + #: Handles logic of visiting annotation nodes + self.annotation_visitor = ImportAnnotationVisitor() #: Whether there is a `from __futures__ import annotations` is present in the file self.futures_annotation: Optional[bool] = None @@ -667,6 +948,16 @@ def change_scope(self, scope: Scope) -> Iterator[None]: yield self.current_scope = old_scope + @property + def unwrapped_annotations(self) -> list[UnwrappedAnnotation]: + """All type annotations in the file, without quotes around them.""" + return self.annotation_visitor.unwrapped_annotations + + @property + def wrapped_annotations(self) -> list[WrappedAnnotation]: + """All type annotations in the file, with quotes around them.""" + return self.annotation_visitor.wrapped_annotations + @property def typing_module_name(self) -> str: """ @@ -939,10 +1230,12 @@ def add_import(self, node: Import) -> None: # noqa: C901 def visit_Import(self, node: ast.Import) -> None: """Append objects to our import map.""" + super().visit_Import(node) self.add_import(node) def visit_ImportFrom(self, node: ast.ImportFrom) -> None: """Append objects to our import map.""" + super().visit_ImportFrom(node) self.add_import(node) def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: @@ -1022,34 +1315,7 @@ def add_annotation( self, node: ast.AST, scope: Scope, type: Literal['annotation', 'alias', 'new-alias'] = 'annotation' ) -> None: """Map all annotations on an AST node.""" - if node is None: - return - if isinstance(node, ast.BinOp): - if not isinstance(node.op, ast.BitOr): - return - self.add_annotation(node.left, scope, type) - self.add_annotation(node.right, scope, type) - elif (py38 and isinstance(node, Index)) or isinstance(node, ast.Attribute): - self.add_annotation(node.value, scope, type) - elif isinstance(node, ast.Subscript): - self.add_annotation(node.value, scope, type) - if getattr(node.value, 'id', '') != 'Literal': - self.add_annotation(node.slice, scope, type) - elif isinstance(node, (ast.Tuple, ast.List)): - for n in node.elts: - self.add_annotation(n, scope, type) - elif isinstance(node, ast.Constant) and isinstance(node.value, str): - # Register annotation value - setattr(node, ANNOTATION_PROPERTY, True) - self.wrapped_annotations.append( - WrappedAnnotation( - node.lineno, node.col_offset, node.value, set(NAME_RE.findall(node.value)), scope, type - ) - ) - elif isinstance(node, ast.Name): - # Register annotation value - setattr(node, ANNOTATION_PROPERTY, True) - self.unwrapped_annotations.append(UnwrappedAnnotation(node.lineno, node.col_offset, node.id, scope, type)) + self.annotation_visitor.visit(node, scope, type) @staticmethod def set_child_node_attribute(node: Any, attr: str, val: Any) -> Any: @@ -1078,6 +1344,7 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None: an annotation as well, but we have to keep in mind that the RHS will not automatically become a ForwardRef with a future import, like a true annotation would. """ + super().visit_AnnAssign(node) self.add_annotation(node.annotation, self.current_scope) if node.value is None: @@ -1494,6 +1761,8 @@ def __init__(self, node: ast.Module, options: Optional[Namespace]) -> None: exempt_modules = getattr(options, 'type_checking_exempt_modules', []) pydantic_enabled = getattr(options, 'type_checking_pydantic_enabled', False) pydantic_enabled_baseclass_passlist = getattr(options, 'type_checking_pydantic_enabled_baseclass_passlist', []) + sqlalchemy_enabled = getattr(options, 'type_checking_sqlalchemy_enabled', False) + sqlalchemy_mapped_dotted_names = getattr(options, 'type_checking_sqlalchemy_mapped_dotted_names', []) fastapi_enabled = getattr(options, 'type_checking_fastapi_enabled', False) fastapi_dependency_support_enabled = getattr(options, 'type_checking_fastapi_dependency_support_enabled', False) cattrs_enabled = getattr(options, 'type_checking_cattrs_enabled', False) @@ -1514,6 +1783,8 @@ def __init__(self, node: ast.Module, options: Optional[Namespace]) -> None: fastapi_enabled=fastapi_enabled, cattrs_enabled=cattrs_enabled, exempt_modules=exempt_modules, + sqlalchemy_enabled=sqlalchemy_enabled, + sqlalchemy_mapped_dotted_names=sqlalchemy_mapped_dotted_names, fastapi_dependency_support_enabled=fastapi_dependency_support_enabled, pydantic_enabled_baseclass_passlist=pydantic_enabled_baseclass_passlist, injector_enabled=injector_enabled, @@ -1545,7 +1816,7 @@ def unused_imports(self) -> Flake8Generator: Classified.BUILTIN: (self.visitor.built_in_imports, TC003), } - unused_imports = set(self.visitor.imports) - self.visitor.names + unused_imports = set(self.visitor.imports) - self.visitor.names - self.visitor.mapped_names used_imports = set(self.visitor.imports) - unused_imports already_imported_modules = [self.visitor.imports[name].module for name in used_imports] annotation_names = [n for i in self.visitor.wrapped_annotations for n in i.names] + [ @@ -1586,7 +1857,7 @@ def used_type_checking_symbols(self) -> Flake8Generator: # only imports and definitions can be moved around continue - if getattr(use, DUNDER_ALL_PROPERTY, False): + if isinstance(use, ast.Constant): # this is actually a quoted name, so it should exist # as long as it's in the scope at all, we don't need # to take the position into account diff --git a/flake8_type_checking/constants.py b/flake8_type_checking/constants.py index 8436af5..5de9c22 100644 --- a/flake8_type_checking/constants.py +++ b/flake8_type_checking/constants.py @@ -6,7 +6,6 @@ ATTRIBUTE_PROPERTY = '_flake8-type-checking__parent' ANNOTATION_PROPERTY = '_flake8-type-checking__is_annotation' -DUNDER_ALL_PROPERTY = '_flake8-type-checking__in__all__' NAME_RE = re.compile(r'(? None: # pragma: no cover default=[], help='Names of base classes to not treat as pydantic models. For example `NamedTuple` or `TypedDict`.', ) + option_manager.add_option( + '--type-checking-sqlalchemy-enabled', + action='store_true', + parse_from_config=True, + default=False, + help='Prevent flagging of annotations on mapped attributes.', + ) + option_manager.add_option( + '--type-checking-sqlalchemy-mapped-dotted-names', + comma_separated_list=True, + parse_from_config=True, + default=[], + help='Dotted names of additional Mapped subclasses. For example `a.MyMapped, a.b.MyMapped`', + ) option_manager.add_option( '--type-checking-fastapi-enabled', action='store_true', diff --git a/setup.cfg b/setup.cfg index 5f8a1ac..bc210df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,10 +12,12 @@ ignore= # W503 line break before binary operator W503, # New black update seem to be incompatible with the next 3 + # E203 whitespace before ':' # E231 missing whitespace after ':' # E241 multiple spaces after ':' # E272 multiple spaces before keyword # E271 multiple spaces after keyword + E203, E231, E241, E272, diff --git a/tests/conftest.py b/tests/conftest.py index a941f98..4b43422 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,6 +41,8 @@ def _get_error(example: str, *, error_code_filter: Optional[str] = None, **kwarg mock_options.type_checking_fastapi_enabled = False mock_options.type_checking_fastapi_dependency_support_enabled = False mock_options.type_checking_pydantic_enabled_baseclass_passlist = [] + mock_options.type_checking_sqlalchemy_enabled = False + mock_options.type_checking_sqlalchemy_mapped_dotted_names = [] mock_options.type_checking_injector_enabled = False mock_options.type_checking_strict = False # kwarg overrides diff --git a/tests/test_import_visitors.py b/tests/test_import_visitors.py index 781f4d5..5ee6792 100644 --- a/tests/test_import_visitors.py +++ b/tests/test_import_visitors.py @@ -19,6 +19,8 @@ def _visit(example: str) -> ImportVisitor: fastapi_enabled=False, fastapi_dependency_support_enabled=False, cattrs_enabled=False, + sqlalchemy_enabled=False, + sqlalchemy_mapped_dotted_names=[], injector_enabled=False, pydantic_enabled_baseclass_passlist=[], ) diff --git a/tests/test_name_visitor.py b/tests/test_name_visitor.py index ee79fd8..c282108 100644 --- a/tests/test_name_visitor.py +++ b/tests/test_name_visitor.py @@ -19,6 +19,8 @@ def _get_names(example: str) -> Set[str]: fastapi_enabled=False, fastapi_dependency_support_enabled=False, cattrs_enabled=False, + sqlalchemy_enabled=False, + sqlalchemy_mapped_dotted_names=[], injector_enabled=False, pydantic_enabled_baseclass_passlist=[], ) diff --git a/tests/test_sqlalchemy.py b/tests/test_sqlalchemy.py new file mode 100644 index 0000000..b8e1079 --- /dev/null +++ b/tests/test_sqlalchemy.py @@ -0,0 +1,163 @@ +""" +This file tests SQLAlchemy support. + +See https://github.com/snok/flake8-type-checking/issues/178 +for discussion on the implementation. +""" + +import textwrap + +import pytest + +from flake8_type_checking.constants import TC002, TC004 +from tests.conftest import _get_error + + +@pytest.mark.parametrize( + ('enabled', 'expected'), + [ + (True, set()), + (False, {'2:0 ' + TC002.format(module='foo.Bar'), '3:0 ' + TC002.format(module='sqlalchemy.orm.Mapped')}), + ], +) +def test_simple_mapped_use(enabled, expected): + """ + Mapped itself must be available at runtime and the inner type may or + may not need to be available at runtime. + """ + example = textwrap.dedent(''' + from foo import Bar + from sqlalchemy.orm import Mapped + + class User: + x: Mapped[Bar] + ''') + assert _get_error(example, error_code_filter='TC002', type_checking_sqlalchemy_enabled=enabled) == expected + + +@pytest.mark.parametrize( + ('name', 'expected'), + [ + ('Mapped', set()), + ('DynamicMapped', set()), + ('WriteOnlyMapped', set()), + ( + 'NotMapped', + {'2:0 ' + TC002.format(module='foo.Bar'), '3:0 ' + TC002.format(module='sqlalchemy.orm.NotMapped')}, + ), + ], +) +def test_default_mapped_names(name, expected): + """Check the three default names and a bogus name.""" + example = textwrap.dedent(f''' + from foo import Bar + from sqlalchemy.orm import {name} + + class User: + x: {name}[Bar] + ''') + assert _get_error(example, error_code_filter='TC002', type_checking_sqlalchemy_enabled=True) == expected + + +def test_mapped_with_circular_forward_reference(): + """ + Mapped must still be available at runtime even with forward references + to a different model. + """ + example = textwrap.dedent(''' + from sqlalchemy.orm import Mapped + if TYPE_CHECKING: + from .address import Address + + class User: + address: Mapped['Address'] + ''') + assert _get_error(example, error_code_filter='TC002', type_checking_sqlalchemy_enabled=True) == set() + + +def test_mapped_use_without_runtime_import(): + """ + Mapped must be available at runtime, so even if it is inside a wrapped annotation + we should raise a TC004 for Mapped but not for Bar + """ + example = textwrap.dedent(''' + if TYPE_CHECKING: + from foo import Bar + from sqlalchemy.orm import Mapped + + class User: + created: 'Mapped[Bar]' + ''') + assert _get_error(example, error_code_filter='TC004', type_checking_sqlalchemy_enabled=True) == { + '4:0 ' + TC004.format(module='Mapped') + } + + +def test_custom_mapped_dotted_names_unwrapped(): + """ + Check a couple of custom dotted names and a bogus one. This also tests the + various styles of imports + """ + example = textwrap.dedent(''' + import a + import a.b as ab + from a import b + from a import MyMapped + from a.b import MyMapped as B + from a import Bogus + from foo import Bar + + class User: + t: MyMapped[Bar] + u: B[Bar] + v: Bogus[Bar] + w: a.MyMapped[Bar] + x: b.MyMapped[Bar] + y: a.b.MyMapped[Bar] + z: ab.MyMapped[Bar] + ''') + assert _get_error( + example, + error_code_filter='TC002', + type_checking_strict=True, # ignore overlapping imports for this test + type_checking_sqlalchemy_enabled=True, + type_checking_sqlalchemy_mapped_dotted_names=['a.MyMapped', 'a.b.MyMapped'], + ) == {'7:0 ' + TC002.format(module='a.Bogus')} + + +def test_custom_mapped_dotted_names_wrapped(): + """ + Same as the unwrapped test but with wrapped annotations. This should generate + a bunch of TC004 errors for the uses of mapped that should be available at runtime. + """ + example = textwrap.dedent(''' + if TYPE_CHECKING: + import a + import a.b as ab + from a import b + from a import MyMapped + from a.b import MyMapped as B + from a import Bogus + from foo import Bar + + class User: + t: 'MyMapped[Bar]' + u: 'B[Bar]' + v: 'Bogus[Bar]' + w: 'a.MyMapped[Bar]' + x: 'b.MyMapped[Bar]' + y: 'a.b.MyMapped[Bar]' + z: 'ab.MyMapped[Bar]' + ''') + assert _get_error( + example, + error_code_filter='TC004', + type_checking_sqlalchemy_enabled=True, + type_checking_sqlalchemy_mapped_dotted_names=['a.MyMapped', 'a.b.MyMapped'], + ) == { + '3:0 ' + TC004.format(module='a'), + '4:0 ' + TC004.format(module='ab'), + '5:0 ' + TC004.format(module='b'), + '6:0 ' + TC004.format(module='MyMapped'), + '7:0 ' + TC004.format(module='B'), + }