From 9c8c49a1f8603f44950ba403989807cfcd661562 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Mon, 18 Jul 2022 20:29:00 +0200 Subject: [PATCH] trio101: never yield inside a nursery or cancel scope --- flake8_trio.py | 44 ++++++++++++++++++++++++++++++++++++++- tests/test_flake8_trio.py | 9 +++++++- tests/trio101.py | 25 ++++++++++++++++++++++ 3 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 tests/trio101.py diff --git a/flake8_trio.py b/flake8_trio.py index 958edd5..a8e7b8c 100644 --- a/flake8_trio.py +++ b/flake8_trio.py @@ -11,7 +11,7 @@ import ast import tokenize -from typing import Any, Generator, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Generator, Iterable, List, Optional, Set, Tuple, Type, Union # CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1" __version__ = "22.7.1" @@ -40,13 +40,24 @@ class Visitor(ast.NodeVisitor): def __init__(self) -> None: super().__init__() self.problems: List[Error] = [] + self.safe_yields: Set[ast.Yield] = set() def visit_With(self, node: ast.With) -> None: self.check_for_trio100(node) + self.check_for_trio101(node) self.generic_visit(node) def visit_AsyncWith(self, node: ast.AsyncWith) -> None: self.check_for_trio100(node) + self.check_for_trio101(node) + self.generic_visit(node) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + self.trio101_mark_yields_safe(node) + self.generic_visit(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + self.trio101_mark_yields_safe(node) self.generic_visit(node) def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]) -> None: @@ -58,6 +69,36 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]) -> None: make_error(TRIO100, item.lineno, item.col_offset, (call,)) ) + def trio101_mark_yields_safe( + self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef] + ) -> None: + if any( + isinstance(d, ast.Name) + and d.id in ("contextmanager", "asynccontextmanager") + for d in node.decorator_list + ): + self.safe_yields.update( + {x for x in ast.walk(node) if isinstance(x, ast.Yield)} + ) + + def check_for_trio101(self, node: Union[ast.With, ast.AsyncWith]) -> None: + for item in (i.context_expr for i in node.items): + call = is_trio_call( + item, + "open_nursery", + "fail_after", + "fail_at", + "move_on_after", + "move_at", + ) + if call and any( + isinstance(x, ast.Yield) and x not in self.safe_yields + for x in ast.walk(node) + ): + self.problems.append( + make_error(TRIO101, item.lineno, item.col_offset, (call,)) + ) + class Plugin: name = __name__ @@ -83,3 +124,4 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]: TRIO100 = "TRIO100: {} context contains no checkpoints, add `await trio.sleep(0)`" +TRIO101 = "TRIO101: {} never yield inside a nursery or cancel scope" diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py index 2d3dd22..c53833b 100644 --- a/tests/test_flake8_trio.py +++ b/tests/test_flake8_trio.py @@ -9,7 +9,7 @@ from hypothesis import HealthCheck, given, settings from hypothesmith import from_grammar -from flake8_trio import TRIO100, Error, Plugin, Visitor, make_error +from flake8_trio import TRIO100, TRIO101, Error, Plugin, Visitor, make_error class Flake8TrioTestCase(unittest.TestCase): @@ -40,6 +40,13 @@ def test_trio100_py39(self): make_error(TRIO100, 14, 8, ("trio.move_on_after",)), ) + def test_trio101(self): + self.assert_expected_errors( + "trio101.py", + make_error(TRIO101, 7, 9, ("trio.open_nursery",)), + make_error(TRIO101, 12, 15, ("trio.open_nursery",)), + ) + @pytest.mark.fuzz class TestFuzz(unittest.TestCase): diff --git a/tests/trio101.py b/tests/trio101.py new file mode 100644 index 0000000..0b9a827 --- /dev/null +++ b/tests/trio101.py @@ -0,0 +1,25 @@ +from contextlib import asynccontextmanager, contextmanager + +import trio + + +def foo0(): + with trio.open_nursery() as _: + yield 1 + + +async def foo1(): + async with trio.open_nursery() as _: + yield 1 + + +@contextmanager +def foo2(): + with trio.open_nursery() as _: + yield 1 + + +@asynccontextmanager +async def foo3(): + async with trio.open_nursery() as _: + yield 1