Skip to content

Commit

Permalink
refactor trying to clean the code and add comments where conditions o…
Browse files Browse the repository at this point in the history
…n instances of walrus operator
  • Loading branch information
aless10 committed Mar 1, 2023
1 parent 52f818d commit 0bc4bdc
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 15 deletions.
29 changes: 14 additions & 15 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@

assertstate_key = StashKey["AssertionState"]()


# pytest caches rewritten pycs in pycache dirs
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
PYC_EXT = ".py" + (__debug__ and "c" or "o")
Expand Down Expand Up @@ -945,7 +944,8 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:

def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
# This method handles the 'walrus operator' repr of the target
# name if it's a local variable or _should_repr_global_name() thinks it's acceptable.
# name if it's a local variable or _should_repr_global_name()
# thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
target_id = name.target.id # type: ignore[attr-defined]
inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs])
Expand Down Expand Up @@ -981,25 +981,23 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
if isinstance(v, ast.Compare):
if isinstance(v.left, namedExpr) and (
# Check if the left operand is a namedExpr and the value has already been visited
if (
isinstance(v, ast.Compare)
and isinstance(v.left, namedExpr)
and (
v.left.target.id
in [
ast_expr.id
for ast_expr in boolop.values[:i]
if hasattr(ast_expr, "id")
]
or v.left.target.id == pytest_temp
):
pytest_temp = f"pytest_{v.left.target.id}_temp"
self.variables_overwrite[v.left.target.id] = pytest_temp
v.left.target.id = pytest_temp

elif isinstance(v.left, ast.Name) and (
pytest_temp is not None
and v.left.id == pytest_temp.lstrip("pytest_").rstrip("_temp")
):
v.left.id = pytest_temp
)
):
pytest_temp = util.compose_temp_variable(v.left.target.id)
self.variables_overwrite[v.left.target.id] = pytest_temp
v.left.target.id = pytest_temp
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
Expand Down Expand Up @@ -1075,6 +1073,7 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:

def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
comp.left.id = self.variables_overwrite[comp.left.id]
left_res, left_expl = self.visit(comp.left)
Expand All @@ -1093,7 +1092,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
and isinstance(left_res, ast.Name)
and next_operand.target.id == left_res.id
):
next_operand.target.id = f"pytest_{left_res.id}_temp"
next_operand.target.id = util.compose_temp_variable(left_res.id)
self.variables_overwrite[left_res.id] = next_operand.target.id
next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
Expand Down
4 changes: 4 additions & 0 deletions src/_pytest/assertion/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,3 +520,7 @@ def running_on_ci() -> bool:
"""Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"]
return any(var in os.environ for var in env_vars)


def compose_temp_variable(original_variable: str) -> str:
return f"pytest_{original_variable}_temp"

0 comments on commit 0bc4bdc

Please sign in to comment.