From 9c8c49a1f8603f44950ba403989807cfcd661562 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Mon, 18 Jul 2022 20:29:00 +0200
Subject: [PATCH] trio101: never yield inside a nursery or cancel scope
---
flake8_trio.py | 44 ++++++++++++++++++++++++++++++++++++++-
tests/test_flake8_trio.py | 9 +++++++-
tests/trio101.py | 25 ++++++++++++++++++++++
3 files changed, 76 insertions(+), 2 deletions(-)
create mode 100644 tests/trio101.py
diff --git a/flake8_trio.py b/flake8_trio.py
index 958edd5..a8e7b8c 100644
--- a/flake8_trio.py
+++ b/flake8_trio.py
@@ -11,7 +11,7 @@
import ast
import tokenize
-from typing import Any, Generator, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Generator, Iterable, List, Optional, Set, Tuple, Type, Union
# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
__version__ = "22.7.1"
@@ -40,13 +40,24 @@ class Visitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self.problems: List[Error] = []
+ self.safe_yields: Set[ast.Yield] = set()
def visit_With(self, node: ast.With) -> None:
self.check_for_trio100(node)
+ self.check_for_trio101(node)
self.generic_visit(node)
def visit_AsyncWith(self, node: ast.AsyncWith) -> None:
self.check_for_trio100(node)
+ self.check_for_trio101(node)
+ self.generic_visit(node)
+
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
+ self.trio101_mark_yields_safe(node)
+ self.generic_visit(node)
+
+ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
+ self.trio101_mark_yields_safe(node)
self.generic_visit(node)
def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]) -> None:
@@ -58,6 +69,36 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]) -> None:
make_error(TRIO100, item.lineno, item.col_offset, (call,))
)
+ def trio101_mark_yields_safe(
+ self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
+ ) -> None:
+ if any(
+ isinstance(d, ast.Name)
+ and d.id in ("contextmanager", "asynccontextmanager")
+ for d in node.decorator_list
+ ):
+ self.safe_yields.update(
+ {x for x in ast.walk(node) if isinstance(x, ast.Yield)}
+ )
+
+ def check_for_trio101(self, node: Union[ast.With, ast.AsyncWith]) -> None:
+ for item in (i.context_expr for i in node.items):
+ call = is_trio_call(
+ item,
+ "open_nursery",
+ "fail_after",
+ "fail_at",
+ "move_on_after",
+ "move_at",
+ )
+ if call and any(
+ isinstance(x, ast.Yield) and x not in self.safe_yields
+ for x in ast.walk(node)
+ ):
+ self.problems.append(
+ make_error(TRIO101, item.lineno, item.col_offset, (call,))
+ )
+
class Plugin:
name = __name__
@@ -83,3 +124,4 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
TRIO100 = "TRIO100: {} context contains no checkpoints, add `await trio.sleep(0)`"
+TRIO101 = "TRIO101: {} never yield inside a nursery or cancel scope"
diff --git a/tests/test_flake8_trio.py b/tests/test_flake8_trio.py
index 2d3dd22..c53833b 100644
--- a/tests/test_flake8_trio.py
+++ b/tests/test_flake8_trio.py
@@ -9,7 +9,7 @@
from hypothesis import HealthCheck, given, settings
from hypothesmith import from_grammar
-from flake8_trio import TRIO100, Error, Plugin, Visitor, make_error
+from flake8_trio import TRIO100, TRIO101, Error, Plugin, Visitor, make_error
class Flake8TrioTestCase(unittest.TestCase):
@@ -40,6 +40,13 @@ def test_trio100_py39(self):
make_error(TRIO100, 14, 8, ("trio.move_on_after",)),
)
+ def test_trio101(self):
+ self.assert_expected_errors(
+ "trio101.py",
+ make_error(TRIO101, 7, 9, ("trio.open_nursery",)),
+ make_error(TRIO101, 12, 15, ("trio.open_nursery",)),
+ )
+
@pytest.mark.fuzz
class TestFuzz(unittest.TestCase):
diff --git a/tests/trio101.py b/tests/trio101.py
new file mode 100644
index 0000000..0b9a827
--- /dev/null
+++ b/tests/trio101.py
@@ -0,0 +1,25 @@
+from contextlib import asynccontextmanager, contextmanager
+
+import trio
+
+
+def foo0():
+ with trio.open_nursery() as _:
+ yield 1
+
+
+async def foo1():
+ async with trio.open_nursery() as _:
+ yield 1
+
+
+@contextmanager
+def foo2():
+ with trio.open_nursery() as _:
+ yield 1
+
+
+@asynccontextmanager
+async def foo3():
+ async with trio.open_nursery() as _:
+ yield 1