Skip to content

Commit

Permalink
trio101: never yield inside a nursery or cancel scope
Browse files Browse the repository at this point in the history
  • Loading branch information
jakkdl committed Jul 19, 2022
1 parent dd1ba98 commit 9c8c49a
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 2 deletions.
44 changes: 43 additions & 1 deletion flake8_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import ast
import tokenize
from typing import Any, Generator, Iterable, List, Optional, Tuple, Type, Union
from typing import Any, Generator, Iterable, List, Optional, Set, Tuple, Type, Union

# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
__version__ = "22.7.1"
Expand Down Expand Up @@ -40,13 +40,24 @@ class Visitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self.problems: List[Error] = []
self.safe_yields: Set[ast.Yield] = set()

def visit_With(self, node: ast.With) -> None:
self.check_for_trio100(node)
self.check_for_trio101(node)
self.generic_visit(node)

def visit_AsyncWith(self, node: ast.AsyncWith) -> None:
self.check_for_trio100(node)
self.check_for_trio101(node)
self.generic_visit(node)

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self.trio101_mark_yields_safe(node)
self.generic_visit(node)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
self.trio101_mark_yields_safe(node)
self.generic_visit(node)

def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]) -> None:
Expand All @@ -58,6 +69,36 @@ def check_for_trio100(self, node: Union[ast.With, ast.AsyncWith]) -> None:
make_error(TRIO100, item.lineno, item.col_offset, (call,))
)

def trio101_mark_yields_safe(
self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> None:
if any(
isinstance(d, ast.Name)
and d.id in ("contextmanager", "asynccontextmanager")
for d in node.decorator_list
):
self.safe_yields.update(
{x for x in ast.walk(node) if isinstance(x, ast.Yield)}
)

def check_for_trio101(self, node: Union[ast.With, ast.AsyncWith]) -> None:
for item in (i.context_expr for i in node.items):
call = is_trio_call(
item,
"open_nursery",
"fail_after",
"fail_at",
"move_on_after",
"move_at",
)
if call and any(
isinstance(x, ast.Yield) and x not in self.safe_yields
for x in ast.walk(node)
):
self.problems.append(
make_error(TRIO101, item.lineno, item.col_offset, (call,))
)


class Plugin:
name = __name__
Expand All @@ -83,3 +124,4 @@ def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:


TRIO100 = "TRIO100: {} context contains no checkpoints, add `await trio.sleep(0)`"
TRIO101 = "TRIO101: {} never yield inside a nursery or cancel scope"
9 changes: 8 additions & 1 deletion tests/test_flake8_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from hypothesis import HealthCheck, given, settings
from hypothesmith import from_grammar

from flake8_trio import TRIO100, Error, Plugin, Visitor, make_error
from flake8_trio import TRIO100, TRIO101, Error, Plugin, Visitor, make_error


class Flake8TrioTestCase(unittest.TestCase):
Expand Down Expand Up @@ -40,6 +40,13 @@ def test_trio100_py39(self):
make_error(TRIO100, 14, 8, ("trio.move_on_after",)),
)

def test_trio101(self):
self.assert_expected_errors(
"trio101.py",
make_error(TRIO101, 7, 9, ("trio.open_nursery",)),
make_error(TRIO101, 12, 15, ("trio.open_nursery",)),
)


@pytest.mark.fuzz
class TestFuzz(unittest.TestCase):
Expand Down
25 changes: 25 additions & 0 deletions tests/trio101.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from contextlib import asynccontextmanager, contextmanager

import trio


def foo0():
with trio.open_nursery() as _:
yield 1


async def foo1():
async with trio.open_nursery() as _:
yield 1


@contextmanager
def foo2():
with trio.open_nursery() as _:
yield 1


@asynccontextmanager
async def foo3():
async with trio.open_nursery() as _:
yield 1

0 comments on commit 9c8c49a

Please sign in to comment.