Skip to content

Commit

Permalink
Fix missing_asynq false positive on lambdas (#399)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Jan 11, 2022
1 parent 3f08b71 commit a305d9b
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 9 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Unreleased

- Fix "This function should have an @asynq() decorator"
false positive on lambdas (#399)
- Fix compatibility between Union and Annotated (#397)
- Fix potential incorrect inferred return value for
unannotated functions (#396)
Expand Down
12 changes: 10 additions & 2 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,7 +1791,11 @@ def _visit_function_body(self, function_info: FunctionInfo) -> FunctionResult:

with qcore.override(
self, "state", VisitorState.collect_names
), qcore.override(self, "return_values", []):
), qcore.override(
self, "return_values", []
), self.yield_checker.set_function_node(
node
):
if isinstance(node, ast.Lambda):
self.visit(node.body)
else:
Expand All @@ -1806,7 +1810,11 @@ def _visit_function_body(self, function_info: FunctionInfo) -> FunctionResult:

with qcore.override(self, "current_class", None), qcore.override(
self, "state", VisitorState.check_names
), qcore.override(self, "return_values", []):
), qcore.override(
self, "return_values", []
), self.yield_checker.set_function_node(
node
):
if isinstance(node, ast.Lambda):
return_values = [self.visit(node.body)]
else:
Expand Down
16 changes: 16 additions & 0 deletions pyanalyze/test_yield_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,22 @@ def capybara(fn, fn2):
""",
)

@assert_passes()
def test_lambda(self):
from asynq import asynq
from asynq.tools import afilter

@asynq()
def inner(obj):
return bool(obj)

@asynq()
def outer(objs):
objs = yield afilter.asynq(
asynq()(lambda obj: not (yield inner.asynq(obj))), objs
)
return objs


class TestDuplicateYield(TestNameCheckVisitorBase):
@assert_passes()
Expand Down
23 changes: 16 additions & 7 deletions pyanalyze/yield_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
Tuple,
)

from pyanalyze.functions import FunctionNode

from .asynq_checker import AsyncFunctionKind
from .error_code import ErrorCode
from .value import Value, KnownValue, UnboundMethodValue, UNINITIALIZED_VALUE
Expand Down Expand Up @@ -181,7 +183,11 @@ class YieldChecker:
previous_yield: Optional[ast.Yield] = None
statement_for_previous_yield: Optional[ast.stmt] = None
used_varnames: Set[str] = field(default_factory=set)
added_decorator: bool = False
current_function_node: Optional[FunctionNode] = None
alerted_nodes: Set[FunctionNode] = field(default_factory=set)

def set_function_node(self, node: FunctionNode) -> ContextManager[None]:
return qcore.override(self, "current_function_node", node)

@contextlib.contextmanager
def check_yield(
Expand Down Expand Up @@ -235,24 +241,27 @@ def reset_yield_checks(self) -> None:

def record_call(self, value: Value, node: ast.Call) -> None:
if (
not self.in_non_async_yield
self.current_function_node is None
or not self.in_non_async_yield
or not self._is_async_call(value, node.func)
or self.added_decorator
or self.current_function_node in self.alerted_nodes
):
return

# prevent ourselves from adding the decorator to the same function multiple times
self.added_decorator = True
func_node = self.visitor.node_context.nearest_enclosing(ast.FunctionDef)
# async functions can't be asynq. Lambdas can be but can't have decorators.
if isinstance(self.current_function_node, (ast.Lambda, ast.AsyncFunctionDef)):
return

lines = self.visitor._lines()
# this doesn't handle decorator order, it just adds @asynq() right before the def
i = func_node.lineno - 1
i = self.current_function_node.lineno - 1
def_line = lines[i]
indentation = len(def_line) - len(def_line.lstrip())
replacement = Replacement([i + 1], [" " * indentation + "@asynq()\n", def_line])
self.visitor.show_error(
node, error_code=ErrorCode.missing_asynq, replacement=replacement
)
self.alerted_nodes.add(self.current_function_node)

# Internal part

Expand Down

0 comments on commit a305d9b

Please sign in to comment.