Skip to content

Commit

Permalink
Merge pull request #17 from jakkdl/7_async_iterable_checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD authored Aug 8, 2022
2 parents dc2cc8a + 71a66b4 commit ff2d0b3
Show file tree
Hide file tree
Showing 5 changed files with 1,002 additions and 152 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Changelog
*[CalVer, YY.month.patch](https://calver.org/)*

## 22.8.2
- Merged TRIO108 into TRIO107
- TRIO108 now handles checkpointing in async iterators

## 22.8.1
- Added TRIO109: Async definitions should not have a `timeout` parameter. Use `trio.[fail/move_on]_[at/after]`
- Added TRIO110: `while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ pip install flake8-trio
- **TRIO104**: `Cancelled` and `BaseException` must be re-raised - when a user tries to `return` or `raise` a different exception.
- **TRIO105**: Calling a trio async function without immediately `await`ing it.
- **TRIO106**: trio must be imported with `import trio` for the linter to work.
- **TRIO107**: Async functions must have at least one checkpoint on every code path, unless an exception is raised.
- **TRIO108**: Early return from async function must have at least one checkpoint on every code path before it, unless an exception is raised.
Checkpoints are `await`, `async with` `async for`.
- **TRIO107**: exit or `return` from async function with no guaranteed checkpoint or exception since function definition.
- **TRIO108**: exit, yield or return from async iterable with no guaranteed checkpoint since possible function entry (yield or function definition)
Checkpoints are `await`, `async for`, and `async with` (on one of enter/exit).
- **TRIO109**: Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead
- **TRIO110**: `while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.
270 changes: 211 additions & 59 deletions flake8_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, Type, Union

# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
__version__ = "22.8.1"
__version__ = "22.8.2"


class Statement(NamedTuple):
Expand All @@ -35,8 +35,8 @@ class Statement(NamedTuple):
"TRIO104": "Cancelled (and therefore BaseException) must be re-raised",
"TRIO105": "trio async function {} must be immediately awaited",
"TRIO106": "trio must be imported with `import trio` for the linter to work",
"TRIO107": "Async functions must have at least one checkpoint on every code path, unless an exception is raised",
"TRIO108": "Early return from async function must have at least one checkpoint on every code path before it.",
"TRIO107": "{0} from async function with no guaranteed checkpoint or exception since function definition on line {1.lineno}",
"TRIO108": "{0} from async iterable with no guaranteed checkpoint since {1.name} on line {1.lineno}",
"TRIO109": "Async function definition with a `timeout` parameter - use `trio.[fail/move_on]_[after/at]` instead",
"TRIO110": "`while <condition>: await trio.sleep()` should be replaced by a `trio.Event`.",
}
Expand Down Expand Up @@ -68,6 +68,7 @@ class Flake8TrioVisitor(ast.NodeVisitor):
def __init__(self):
super().__init__()
self._problems: List[Error] = []
self.suppress_errors = False

@classmethod
def run(cls, tree: ast.AST) -> Iterable[Error]:
Expand All @@ -90,9 +91,10 @@ def visit_nodes(
visit(node)

def error(self, error: str, node: HasLineInfo, *args: Any, **kwargs: Any):
self._problems.append(
make_error(error, node.lineno, node.col_offset, *args, **kwargs)
)
if not self.suppress_errors:
self._problems.append(
make_error(error, node.lineno, node.col_offset, *args, **kwargs)
)

def get_state(self, *attrs: str) -> Dict[str, Any]:
if not attrs:
Expand All @@ -103,6 +105,10 @@ def set_state(self, attrs: Dict[str, Any]):
for attr, value in attrs.items():
setattr(self, attr, value)

def walk(self, *body: ast.AST) -> Iterable[ast.AST]:
for b in body:
yield from ast.walk(b)


class TrioScope:
def __init__(self, node: ast.Call, funcname: str, packagename: str):
Expand Down Expand Up @@ -561,105 +567,251 @@ def visit_Call(self, node: ast.Call):
class Visitor107_108(Flake8TrioVisitor):
def __init__(self):
super().__init__()
self.all_await = True
self.yield_count = 0

self.always_checkpoint: Optional[Statement] = None
self.checkpoint_continue: Optional[Statement] = None
self.checkpoint_break: Optional[Statement] = None

self.default = self.get_state()

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
outer = self.all_await
if has_decorator(node.decorator_list, "overload"):
return

outer = self.get_state()
self.set_state(self.default)

self.always_checkpoint = Statement("function definition", node.lineno)

# do not require checkpointing if overloading
self.all_await = has_decorator(node.decorator_list, "overload")
self.generic_visit(node)
self.check_function_exit(node)

if not self.all_await:
self.error("TRIO107", node)
self.set_state(outer)

self.all_await = outer
def check_function_exit(self, node: Union[ast.Return, ast.AsyncFunctionDef]):
# error if function exits w/o guaranteed checkpoint since function entry
method = "return" if isinstance(node, ast.Return) else "exit"

if self.always_checkpoint is not None:
if self.yield_count:
self.error("TRIO108", node, method, self.always_checkpoint)
else:
self.error("TRIO107", node, method, self.always_checkpoint)

def visit_Return(self, node: ast.Return):
self.generic_visit(node)
if not self.all_await:
self.error("TRIO108", node)
self.check_function_exit(node)

# avoid duplicate error messages
self.all_await = True
self.always_checkpoint = None

# disregard raise's in nested functions
# disregard checkpoints in nested function definitions
def visit_FunctionDef(self, node: ast.FunctionDef):
outer = self.all_await
outer = self.get_state()
self.set_state(self.default)
self.generic_visit(node)
self.all_await = outer
self.set_state(outer)

# checkpoint functions
def visit_Await(
self, node: Union[ast.Await, ast.AsyncFor, ast.AsyncWith, ast.Raise]
):
def visit_Await(self, node: Union[ast.Await, ast.Raise]):
# the expression being awaited is not checkpointed
# so only set checkpoint after the await node
self.generic_visit(node)
self.all_await = True

visit_AsyncFor = visit_Await
visit_AsyncWith = visit_Await
self.always_checkpoint = None

# raising exception means we don't need to checkpoint so we can treat it as one
visit_Raise = visit_Await

# valid checkpoint if there's valid checkpoints (or raise) in at least one of:
# (try or else) and all excepts
# finally
# guaranteed to checkpoint on at least one of enter and exit
# if it checkpoints on entry and there's a yield in it, we can't treat it as checkpoint
# but it may not checkpoint on entry, so yields inside need to raise problem
def visit_AsyncWith(self, node: ast.AsyncWith):
self.visit_nodes(node.items)
prebody_yield_count = self.yield_count

# there's no guarantee of checkpoint before entry
self.visit_nodes(node.body)

# no yield in body, treat as checkpoint
if prebody_yield_count == self.yield_count:
self.always_checkpoint = None

# error if no checkpoint since earlier yield or function entry
def visit_Yield(self, node: ast.Yield):
self.generic_visit(node)
self.yield_count += 1
if self.always_checkpoint is not None:
self.error("TRIO108", node, "yield", self.always_checkpoint)

# mark as requiring checkpoint after
self.always_checkpoint = Statement("yield", node.lineno)

# valid checkpoint if there's valid checkpoints (or raise) in:
# (try or else) and all excepts, or in finally
#
# try can jump into any except or into the finally* at any point during it's
# execution so we need to make sure except & finally can handle worst-case
# * unless there's a bare except / except BaseException - not implemented.
def visit_Try(self, node: ast.Try):
if self.all_await:
self.generic_visit(node)
return
# except & finally guaranteed to enter with checkpoint if checkpointed
# before try and no yield in try body.
body_always_checkpoint = self.always_checkpoint
for inner_node in self.walk(*node.body):
if isinstance(inner_node, ast.Yield):
body_always_checkpoint = Statement("yield", inner_node.lineno)
break

# check try body
self.visit_nodes(node.body)
body_await = self.all_await
self.all_await = False

# save state at end of try for entering else
try_checkpoint = self.always_checkpoint

# check that all except handlers checkpoint (await or most likely raise)
all_except_await = True
all_except_checkpoint: Optional[Statement] = None
for handler in node.handlers:
# enter with worst case of try
self.always_checkpoint = body_always_checkpoint

self.visit_nodes(handler)
all_except_await &= self.all_await
self.all_await = False

if self.always_checkpoint is not None:
all_except_checkpoint = self.always_checkpoint

# check else
# if else runs it's after all of try, so restore state to back then
self.always_checkpoint = try_checkpoint
self.visit_nodes(node.orelse)

# (try or else) and all excepts
self.all_await = (body_await or self.all_await) and all_except_await
# checkpoint if else checkpoints, and all excepts checkpoint
if all_except_checkpoint is not None:
self.always_checkpoint = all_except_checkpoint

# finally can check on it's own
self.visit_nodes(node.finalbody)
# if there's no finally, don't restore state from try
if node.finalbody:
# can enter from try, else, or any except
if body_always_checkpoint is not None:
self.always_checkpoint = body_always_checkpoint
self.visit_nodes(node.finalbody)

# valid checkpoint if both body and orelse have checkpoints
# valid checkpoint if both body and orelse checkpoint
def visit_If(self, node: Union[ast.If, ast.IfExp]):
if self.all_await:
self.generic_visit(node)
return

# ignore checkpoints in condition
# visit condition
self.visit_nodes(node.test)
self.all_await = False
outer = self.get_state("always_checkpoint")

# check body
# visit body
self.visit_nodes(node.body)
body_await = self.all_await
self.all_await = False
body_outer = self.get_state("always_checkpoint")

# reset to after condition and visit orelse
self.set_state(outer)
self.visit_nodes(node.orelse)

# checkpoint if both body and else
self.all_await = body_await and self.all_await
# if body failed, reset to that state
if body_outer["always_checkpoint"] is not None:
self.set_state(body_outer)

# otherwise keep state (fail or not) as it was after orelse

# inline if
visit_IfExp = visit_If

# ignore checkpoints in loops due to continue/break shenanigans
def visit_While(self, node: Union[ast.While, ast.For]):
outer = self.all_await
self.generic_visit(node)
self.all_await = outer
# Check for yields w/o checkpoint inbetween due to entering loop body the first time,
# after completing all of loop body, and after any continues.
# yield in else have same requirement
# state after the loop same as above, and in addition the state at any break
def visit_loop(self, node: Union[ast.While, ast.For, ast.AsyncFor]):
# save state in case of nested loops
outer = self.get_state(
"checkpoint_continue", "checkpoint_break", "suppress_errors"
)

# visit condition
if isinstance(node, ast.While):
self.visit_nodes(node.test)
else:
self.visit_nodes(node.target)
self.visit_nodes(node.iter)

self.checkpoint_continue = None
pre_body_always_checkpoint = self.always_checkpoint

# AsyncFor guaranteed checkpoint at every iteration
if isinstance(node, ast.AsyncFor):
pre_body_always_checkpoint = None
self.always_checkpoint = None

# if we normally enter loop with checkpoint, check for worst-case start of loop
# due to `continue` or multiple iterations
elif self.always_checkpoint is None:
# silently check if body unsets yield
# so we later can check if body errors out on worst case of entering
self.suppress_errors = True

# self.checkpoint_continue is set to False if loop body ever does
# continue with self.always_checkpoint == False
self.visit_nodes(node.body)

self.suppress_errors = outer["suppress_errors"]

if self.checkpoint_continue is not None:
self.always_checkpoint = self.checkpoint_continue

self.checkpoint_break = None
self.visit_nodes(node.body)

# AsyncFor guarantees checkpoint on running out of iterable
# so reset checkpoint state at end of loop. (but not state at break)
if isinstance(node, ast.AsyncFor):
self.always_checkpoint = None
else:
# enter orelse with worst case:
# loop body might execute fully before entering orelse
# (current state of self.always_checkpoint)
# or not at all
if pre_body_always_checkpoint is not None:
self.always_checkpoint = pre_body_always_checkpoint
# or at a continue
elif self.checkpoint_continue is not None:
self.always_checkpoint = self.checkpoint_continue

# visit orelse
self.visit_nodes(node.orelse)

# We may exit from:
# orelse (which covers no body, body until continue, and all body)
# break
if self.checkpoint_break is not None:
self.always_checkpoint = self.checkpoint_break

# reset state in case of nested loops
self.set_state(outer)

visit_For = visit_While
visit_While = visit_loop
visit_For = visit_loop
visit_AsyncFor = visit_loop

# save state in case of continue/break at a point not guaranteed to checkpoint
def visit_Continue(self, node: ast.Continue):
if self.always_checkpoint is not None:
self.checkpoint_continue = self.always_checkpoint

def visit_Break(self, node: ast.Break):
if self.always_checkpoint is not None:
self.checkpoint_break = self.always_checkpoint

# first node in a condition is guaranteed to run, but may shortcut so checkpoints
# in remaining nodes are not guaranteed
# Not fully implemented: worst case shortcut with yields in condition
def visit_BoolOp(self, node: ast.BoolOp):
self.visit(node.op)
self.visit_nodes(node.values[:1])
outer = self.always_checkpoint
self.visit_nodes(node.values[1:])

self.always_checkpoint = outer


class Plugin:
Expand Down
Loading

0 comments on commit ff2d0b3

Please sign in to comment.