diff --git a/docs/changelog.md b/docs/changelog.md index bf480a65..1b651402 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,8 @@ ## Unreleased +- Fix "This function should have an @asynq() decorator" + false positive on lambdas (#399) - Fix compatibility between Union and Annotated (#397) - Fix potential incorrect inferred return value for unannotated functions (#396) diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index 80a2123c..9b142e5d 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -1791,7 +1791,11 @@ def _visit_function_body(self, function_info: FunctionInfo) -> FunctionResult: with qcore.override( self, "state", VisitorState.collect_names - ), qcore.override(self, "return_values", []): + ), qcore.override( + self, "return_values", [] + ), self.yield_checker.set_function_node( + node + ): if isinstance(node, ast.Lambda): self.visit(node.body) else: @@ -1806,7 +1810,11 @@ def _visit_function_body(self, function_info: FunctionInfo) -> FunctionResult: with qcore.override(self, "current_class", None), qcore.override( self, "state", VisitorState.check_names - ), qcore.override(self, "return_values", []): + ), qcore.override( + self, "return_values", [] + ), self.yield_checker.set_function_node( + node + ): if isinstance(node, ast.Lambda): return_values = [self.visit(node.body)] else: diff --git a/pyanalyze/test_yield_checker.py b/pyanalyze/test_yield_checker.py index 2642c69a..434737db 100644 --- a/pyanalyze/test_yield_checker.py +++ b/pyanalyze/test_yield_checker.py @@ -568,6 +568,22 @@ def capybara(fn, fn2): """, ) + @assert_passes() + def test_lambda(self): + from asynq import asynq + from asynq.tools import afilter + + @asynq() + def inner(obj): + return bool(obj) + + @asynq() + def outer(objs): + objs = yield afilter.asynq( + asynq()(lambda obj: not (yield inner.asynq(obj))), objs + ) + return objs + class TestDuplicateYield(TestNameCheckVisitorBase): @assert_passes() diff --git a/pyanalyze/yield_checker.py b/pyanalyze/yield_checker.py index aa9ed689..c747ce22 100644 --- a/pyanalyze/yield_checker.py +++ b/pyanalyze/yield_checker.py @@ -31,6 +31,8 @@ Tuple, ) +from pyanalyze.functions import FunctionNode + from .asynq_checker import AsyncFunctionKind from .error_code import ErrorCode from .value import Value, KnownValue, UnboundMethodValue, UNINITIALIZED_VALUE @@ -181,7 +183,11 @@ class YieldChecker: previous_yield: Optional[ast.Yield] = None statement_for_previous_yield: Optional[ast.stmt] = None used_varnames: Set[str] = field(default_factory=set) - added_decorator: bool = False + current_function_node: Optional[FunctionNode] = None + alerted_nodes: Set[FunctionNode] = field(default_factory=set) + + def set_function_node(self, node: FunctionNode) -> ContextManager[None]: + return qcore.override(self, "current_function_node", node) @contextlib.contextmanager def check_yield( @@ -235,24 +241,27 @@ def reset_yield_checks(self) -> None: def record_call(self, value: Value, node: ast.Call) -> None: if ( - not self.in_non_async_yield + self.current_function_node is None + or not self.in_non_async_yield or not self._is_async_call(value, node.func) - or self.added_decorator + or self.current_function_node in self.alerted_nodes ): return - # prevent ourselves from adding the decorator to the same function multiple times - self.added_decorator = True - func_node = self.visitor.node_context.nearest_enclosing(ast.FunctionDef) + # async functions can't be asynq. Lambdas can be but can't have decorators. + if isinstance(self.current_function_node, (ast.Lambda, ast.AsyncFunctionDef)): + return + lines = self.visitor._lines() # this doesn't handle decorator order, it just adds @asynq() right before the def - i = func_node.lineno - 1 + i = self.current_function_node.lineno - 1 def_line = lines[i] indentation = len(def_line) - len(def_line.lstrip()) replacement = Replacement([i + 1], [" " * indentation + "@asynq()\n", def_line]) self.visitor.show_error( node, error_code=ErrorCode.missing_asynq, replacement=replacement ) + self.alerted_nodes.add(self.current_function_node) # Internal part