From 5324a577b11cd0d9379aaf462928afeb66dd83ec Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sat, 20 Apr 2024 14:54:04 +0100 Subject: [PATCH] [`flake8-pyi`] Allow overloaded `__exit__` and `__aexit__` definitions (`PYI036`) --- .../test/fixtures/flake8_pyi/PYI036.py | 42 ++++- .../test/fixtures/flake8_pyi/PYI036.pyi | 42 ++++- .../src/checkers/ast/analyze/statement.rs | 2 +- .../flake8_pyi/rules/exit_annotations.rs | 144 +++++++++++++----- ...__flake8_pyi__tests__PYI036_PYI036.py.snap | 34 ++++- ..._flake8_pyi__tests__PYI036_PYI036.pyi.snap | 28 +++- 6 files changed, 249 insertions(+), 43 deletions(-) diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.py b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.py index 57f71dc39e0148..ba7c8b43a2e8d0 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.py @@ -3,7 +3,7 @@ import typing from collections.abc import Awaitable from types import TracebackType -from typing import Any, Type +from typing import Any, Type, overload import _typeshed import typing_extensions @@ -73,3 +73,43 @@ async def __aexit__(self, /, typ: type[BaseException] | None, *args: Any) -> Awa class BadSix: def __exit__(self, typ, exc, tb, weird_extra_arg, extra_arg2 = None) -> None: ... # PYI036: Extra arg must have default async def __aexit__(self, typ, exc, tb, *, weird_extra_arg) -> None: ... # PYI036: kwargs must have default + +# Here come the overloads... + +class AcceptableOverload1: + @overload + def __exit__(self, exc_typ: None, exc: None, exc_tb: None) -> None: ... + @overload + def __exit__(self, exc_typ: type[BaseException], exc: BaseException, exc_tb: TracebackType) -> None: ... + def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ... + +# Using `object` or `Unused` in an overload definition is kinda strange, +# but let's allow it to be on the safe side +class AcceptableOverload2: + @overload + def __exit__(self, exc_typ: None, exc: None, exc_tb: object) -> None: ... + @overload + def __exit__(self, exc_typ: Unused, exc: BaseException, exc_tb: object) -> None: ... + def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ... + +class AcceptableOverload2: + # Just ignore any overloads that don't have exactly 3 annotated non-self parameters. + # We don't have the ability (yet) to do arbitrary checking + # of whether one function definition is a subtype of another... + @overload + def __exit__(self, exc_typ: bool, exc: bool, exc_tb: bool, weird_extra_arg: bool) -> None: ... + @overload + def __exit__(self, *args: object) -> None: ... + def __exit__(self, *args: object) -> None: ... + +class UnacceptableOverload1: + @overload + def __exit__(self, exc_typ: None, exc: None, tb: None) -> None: ... # Okay + @overload + def __exit__(self, exc_typ: Exception, exc: Exception, tb: TracebackType) -> None: ... # PYI036 + +class UnacceptableOverload2: + @overload + def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036 + @overload + def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036 diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.pyi b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.pyi index 3f58441c94f215..947bd265ef1dc0 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.pyi +++ b/crates/ruff_linter/resources/test/fixtures/flake8_pyi/PYI036.pyi @@ -3,7 +3,7 @@ import types import typing from collections.abc import Awaitable from types import TracebackType -from typing import Any, Type +from typing import Any, Type, overload import _typeshed import typing_extensions @@ -80,3 +80,43 @@ def isolated_scope(): class ShouldNotError: def __exit__(self, typ: Type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None) -> None: ... + +# Here come the overloads... + +class AcceptableOverload1: + @overload + def __exit__(self, exc_typ: None, exc: None, exc_tb: None) -> None: ... + @overload + def __exit__(self, exc_typ: type[BaseException], exc: BaseException, exc_tb: TracebackType) -> None: ... + def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ... + +# Using `object` or `Unused` in an overload definition is kinda strange, +# but let's allow it to be on the safe side +class AcceptableOverload2: + @overload + def __exit__(self, exc_typ: None, exc: None, exc_tb: object) -> None: ... + @overload + def __exit__(self, exc_typ: Unused, exc: BaseException, exc_tb: object) -> None: ... + def __exit__(self, exc_typ: type[BaseException] | None, exc: BaseException | None, exc_tb: TracebackType | None) -> None: ... + +class AcceptableOverload2: + # Just ignore any overloads that don't have exactly 3 annotated non-self parameters. + # We don't have the ability (yet) to do arbitrary checking + # of whether one function definition is a subtype of another... + @overload + def __exit__(self, exc_typ: bool, exc: bool, exc_tb: bool, weird_extra_arg: bool) -> None: ... + @overload + def __exit__(self, *args: object) -> None: ... + def __exit__(self, *args: object) -> None: ... + +class UnacceptableOverload1: + @overload + def __exit__(self, exc_typ: None, exc: None, tb: None) -> None: ... # Okay + @overload + def __exit__(self, exc_typ: Exception, exc: Exception, tb: TracebackType) -> None: ... # PYI036 + +class UnacceptableOverload2: + @overload + def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036 + @overload + def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036 diff --git a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs index 18326189b8cf14..0190de443528c2 100644 --- a/crates/ruff_linter/src/checkers/ast/analyze/statement.rs +++ b/crates/ruff_linter/src/checkers/ast/analyze/statement.rs @@ -174,7 +174,7 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { } } if checker.enabled(Rule::BadExitAnnotation) { - flake8_pyi::rules::bad_exit_annotation(checker, *is_async, name, parameters); + flake8_pyi::rules::bad_exit_annotation(checker, function_def); } if checker.enabled(Rule::RedundantNumericUnion) { flake8_pyi::rules::redundant_numeric_union(checker, parameters); diff --git a/crates/ruff_linter/src/rules/flake8_pyi/rules/exit_annotations.rs b/crates/ruff_linter/src/rules/flake8_pyi/rules/exit_annotations.rs index c7b18b3d8eb47d..e160e3ac15f71e 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/rules/exit_annotations.rs +++ b/crates/ruff_linter/src/rules/flake8_pyi/rules/exit_annotations.rs @@ -1,16 +1,16 @@ use std::fmt::{Display, Formatter}; use ruff_python_ast::{ - Expr, ExprBinOp, ExprSubscript, ExprTuple, Identifier, Operator, ParameterWithDefault, - Parameters, + Expr, ExprBinOp, ExprSubscript, ExprTuple, Operator, ParameterWithDefault, Parameters, + StmtFunctionDef, }; use smallvec::SmallVec; use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation}; use ruff_macros::{derive_message_formats, violation}; -use ruff_python_semantic::SemanticModel; -use ruff_text_size::Ranged; +use ruff_python_semantic::{analyze::visibility::is_overload, SemanticModel}; +use ruff_text_size::{Ranged, TextRange}; use crate::checkers::ast::Checker; @@ -68,6 +68,10 @@ impl Violation for BadExitAnnotation { ErrorKind::FirstArgBadAnnotation => format!("The first argument in `{method_name}` should be annotated with `object` or `type[BaseException] | None`"), ErrorKind::SecondArgBadAnnotation => format!("The second argument in `{method_name}` should be annotated with `object` or `BaseException | None`"), ErrorKind::ThirdArgBadAnnotation => format!("The third argument in `{method_name}` should be annotated with `object` or `types.TracebackType | None`"), + ErrorKind::UnrecognizedExitOverload => format!( + "Annotations for a three-argument `{method_name}` overload (excluding `self`) \ + should either be `None, None, None` or `type[BaseException], BaseException, types.TracebackType`" + ) } } @@ -104,37 +108,52 @@ enum ErrorKind { ThirdArgBadAnnotation, ArgsAfterFirstFourMustHaveDefault, AllKwargsMustHaveDefault, + UnrecognizedExitOverload, } /// PYI036 -pub(crate) fn bad_exit_annotation( - checker: &mut Checker, - is_async: bool, - name: &Identifier, - parameters: &Parameters, -) { +pub(crate) fn bad_exit_annotation(checker: &mut Checker, function: &StmtFunctionDef) { + let StmtFunctionDef { + is_async, + decorator_list, + name, + parameters, + .. + } = function; + let func_kind = match name.as_str() { "__exit__" if !is_async => FuncKind::Sync, - "__aexit__" if is_async => FuncKind::Async, + "__aexit__" if *is_async => FuncKind::Async, _ => return, }; - let positional_args = parameters + let non_self_positional_args = parameters .args .iter() + .skip(1) .chain(parameters.posonlyargs.iter()) - .collect::>(); + .collect::>(); + + if is_overload(decorator_list, checker.semantic()) { + check_positional_args_for_overloaded_method( + checker, + &non_self_positional_args, + func_kind, + parameters.range(), + ); + return; + } // If there are less than three positional arguments, at least one of them must be a star-arg, // and it must be annotated with `object`. - if positional_args.len() < 4 { + if non_self_positional_args.len() < 3 { check_short_args_list(checker, parameters, func_kind); } // Every positional argument (beyond the first four) must have a default. - for parameter in positional_args + for parameter in non_self_positional_args .iter() - .skip(4) + .skip(3) .filter(|parameter| parameter.default.is_none()) { checker.diagnostics.push(Diagnostic::new( @@ -161,7 +180,7 @@ pub(crate) fn bad_exit_annotation( )); } - check_positional_args(checker, &positional_args, func_kind); + check_positional_args_for_non_overloaded_method(checker, &non_self_positional_args, func_kind); } /// Determine whether a "short" argument list (i.e., an argument list with less than four elements) @@ -204,11 +223,11 @@ fn check_short_args_list(checker: &mut Checker, parameters: &Parameters, func_ki } } -/// Determines whether the positional arguments of an `__exit__` or `__aexit__` method are -/// annotated correctly. -fn check_positional_args( +/// Determines whether the positional arguments of an `__exit__` or `__aexit__` method +/// (that is not decorated with `@typing.overload`) are annotated correctly. +fn check_positional_args_for_non_overloaded_method( checker: &mut Checker, - positional_args: &[&ParameterWithDefault], + non_self_positional_args: &[&ParameterWithDefault], kind: FuncKind, ) { // For each argument, define the predicate against which to check the annotation. @@ -222,7 +241,7 @@ fn check_positional_args( (ErrorKind::ThirdArgBadAnnotation, is_traceback_type), ]; - for (arg, (error_info, predicate)) in positional_args.iter().skip(1).take(3).zip(validations) { + for (arg, (error_info, predicate)) in non_self_positional_args.iter().take(3).zip(validations) { let Some(annotation) = arg.parameter.annotation.as_ref() else { continue; }; @@ -249,6 +268,69 @@ fn check_positional_args( } } +/// Determines whether the positional arguments of an `__exit__` or `__aexit__` method +/// overload are annotated correctly. +fn check_positional_args_for_overloaded_method( + checker: &mut Checker, + non_self_positional_args: &[&ParameterWithDefault], + kind: FuncKind, + parameters_range: TextRange, +) { + let semantic = checker.semantic(); + if non_self_positional_args.len() != 3 { + return; + } + let non_self_annotations: SmallVec<[&Expr; 3]> = non_self_positional_args + .iter() + .filter_map(|arg| arg.parameter.annotation.as_deref()) + .collect(); + if non_self_annotations.len() != 3 { + return; + } + // We've now established that it's a function overload with 3 non-self positional arguments, + // where all arguments are annotated. It therefore follows that, in order for it to be + // correctly annotated, it must be one of the following two possible overloads: + // + // ``` + // @overload + // def __(a)exit__(self, typ: None, exc: None, tb: None) -> None: ... + // @overload + // def __(a)exit__(self, typ: type[BaseException], exc: BaseException, tb: TracebackType) -> None: ... + // ``` + // + // We'll allow small variations on either of these (if, e.g. a parameter is annotated + // with `object` or `_typeshed.Unused`). *Basically*, though, the rule is: + // - If the function overload matches *either* of those, it's okay. + // - If not: emit a diagnostic. + // + // Start by checking the first possibility: + if non_self_annotations.iter().all(|annotation| { + annotation.is_none_literal_expr() | is_object_or_unused(annotation, semantic) + }) { + return; + } + // Now check the second: + let matches_second_overload_variant = { + (is_base_exception_type(non_self_annotations[0], semantic) + || is_object_or_unused(non_self_annotations[0], semantic)) + && (semantic.match_builtin_expr(non_self_annotations[1], "BaseException") + || is_object_or_unused(non_self_annotations[1], semantic)) + && (is_traceback_type(non_self_annotations[2], semantic) + || is_object_or_unused(non_self_annotations[2], semantic)) + }; + if matches_second_overload_variant { + return; + } + // Okay, neither of them match... + checker.diagnostics.push(Diagnostic::new( + BadExitAnnotation { + func_kind: kind, + error_kind: ErrorKind::UnrecognizedExitOverload, + }, + parameters_range, + )); +} + /// Return the non-`None` annotation element of a PEP 604-style union or `Optional` annotation. fn non_none_annotation_element<'a>( annotation: &'a Expr, @@ -256,12 +338,9 @@ fn non_none_annotation_element<'a>( ) -> Option<&'a Expr> { // E.g., `typing.Union` or `typing.Optional` if let Expr::Subscript(ExprSubscript { value, slice, .. }) = annotation { - let qualified_name = semantic.resolve_qualified_name(value); + let qualified_name = semantic.resolve_qualified_name(value)?; - if qualified_name - .as_ref() - .is_some_and(|value| semantic.match_typing_qualified_name(value, "Optional")) - { + if semantic.match_typing_qualified_name(&qualified_name, "Optional") { return if slice.is_none_literal_expr() { None } else { @@ -269,16 +348,11 @@ fn non_none_annotation_element<'a>( }; } - if !qualified_name - .as_ref() - .is_some_and(|value| semantic.match_typing_qualified_name(value, "Union")) - { + if !semantic.match_typing_qualified_name(&qualified_name, "Union") { return None; } - let Expr::Tuple(ExprTuple { elts, .. }) = slice.as_ref() else { - return None; - }; + let ExprTuple { elts, .. } = slice.as_tuple_expr()?; let [left, right] = elts.as_slice() else { return None; @@ -318,7 +392,6 @@ fn non_none_annotation_element<'a>( fn is_object_or_unused(expr: &Expr, semantic: &SemanticModel) -> bool { semantic .resolve_qualified_name(expr) - .as_ref() .is_some_and(|qualified_name| { matches!( qualified_name.segments(), @@ -331,7 +404,6 @@ fn is_object_or_unused(expr: &Expr, semantic: &SemanticModel) -> bool { fn is_traceback_type(expr: &Expr, semantic: &SemanticModel) -> bool { semantic .resolve_qualified_name(expr) - .as_ref() .is_some_and(|qualified_name| { matches!(qualified_name.segments(), ["types", "TracebackType"]) }) diff --git a/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI036_PYI036.py.snap b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI036_PYI036.py.snap index 5e891fa5d9323c..18e14a1b066538 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI036_PYI036.py.snap +++ b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI036_PYI036.py.snap @@ -156,6 +156,34 @@ PYI036.py:75:48: PYI036 All keyword-only arguments in `__aexit__` must have a de 74 | def __exit__(self, typ, exc, tb, weird_extra_arg, extra_arg2 = None) -> None: ... # PYI036: Extra arg must have default 75 | async def __aexit__(self, typ, exc, tb, *, weird_extra_arg) -> None: ... # PYI036: kwargs must have default | ^^^^^^^^^^^^^^^ PYI036 - | - - +76 | +77 | # Here come the overloads... + | + +PYI036.py:109:17: PYI036 Annotations for a three-argument `__exit__` overload (excluding `self`) should either be `None, None, None` or `type[BaseException], BaseException, types.TracebackType` + | +107 | def __exit__(self, exc_typ: None, exc: None, tb: None) -> None: ... # Okay +108 | @overload +109 | def __exit__(self, exc_typ: Exception, exc: Exception, tb: TracebackType) -> None: ... # PYI036 + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PYI036 +110 | +111 | class UnacceptableOverload2: + | + +PYI036.py:113:17: PYI036 Annotations for a three-argument `__exit__` overload (excluding `self`) should either be `None, None, None` or `type[BaseException], BaseException, types.TracebackType` + | +111 | class UnacceptableOverload2: +112 | @overload +113 | def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036 + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PYI036 +114 | @overload +115 | def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036 + | + +PYI036.py:115:17: PYI036 Annotations for a three-argument `__exit__` overload (excluding `self`) should either be `None, None, None` or `type[BaseException], BaseException, types.TracebackType` + | +113 | def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036 +114 | @overload +115 | def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036 + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PYI036 + | diff --git a/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI036_PYI036.pyi.snap b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI036_PYI036.pyi.snap index 2d7817652e298e..e4892784a039f9 100644 --- a/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI036_PYI036.pyi.snap +++ b/crates/ruff_linter/src/rules/flake8_pyi/snapshots/ruff_linter__rules__flake8_pyi__tests__PYI036_PYI036.pyi.snap @@ -168,4 +168,30 @@ PYI036.pyi:75:48: PYI036 All keyword-only arguments in `__aexit__` must have a d | ^^^^^^^^^^^^^^^ PYI036 | - +PYI036.pyi:116:17: PYI036 Annotations for a three-argument `__exit__` overload (excluding `self`) should either be `None, None, None` or `type[BaseException], BaseException, types.TracebackType` + | +114 | def __exit__(self, exc_typ: None, exc: None, tb: None) -> None: ... # Okay +115 | @overload +116 | def __exit__(self, exc_typ: Exception, exc: Exception, tb: TracebackType) -> None: ... # PYI036 + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PYI036 +117 | +118 | class UnacceptableOverload2: + | + +PYI036.pyi:120:17: PYI036 Annotations for a three-argument `__exit__` overload (excluding `self`) should either be `None, None, None` or `type[BaseException], BaseException, types.TracebackType` + | +118 | class UnacceptableOverload2: +119 | @overload +120 | def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036 + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PYI036 +121 | @overload +122 | def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036 + | + +PYI036.pyi:122:17: PYI036 Annotations for a three-argument `__exit__` overload (excluding `self`) should either be `None, None, None` or `type[BaseException], BaseException, types.TracebackType` + | +120 | def __exit__(self, exc_typ: type[BaseException] | None, exc: None, tb: None) -> None: ... # PYI036 +121 | @overload +122 | def __exit__(self, exc_typ: object, exc: Exception, tb: builtins.TracebackType) -> None: ... # PYI036 + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PYI036 + |