Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix different behavior with unittest when warlus operator #10758

Merged
merged 13 commits into from
Mar 10, 2023
Merged
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Adam Uhlir
Ahn Ki-Wook
Akiomi Kamakura
Alan Velasco
Alessio Izzo
Alexander Johnson
Alexander King
Alexei Kozlenok
Expand Down
1 change: 1 addition & 0 deletions changelog/10743.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The assertion rewriting mechanism now works correctly when assertion expressions contain the walrus operator.
54 changes: 50 additions & 4 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,13 @@
if TYPE_CHECKING:
from _pytest.assertion import AssertionState

if sys.version_info >= (3, 8):
namedExpr = ast.NamedExpr
else:
namedExpr = ast.Expr

assertstate_key = StashKey["AssertionState"]()

assertstate_key = StashKey["AssertionState"]()

# pytest caches rewritten pycs in pycache dirs
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
Expand Down Expand Up @@ -635,8 +639,12 @@ class AssertionRewriter(ast.NodeVisitor):
.push_format_context() and .pop_format_context() which allows
to build another %-formatted string while already building one.

This state is reset on every new assert statement visited and used
by the other visitors.
:variables_overwrite: A dict filled with references to variables
that change value within an assert. This happens when a variable is
reassigned with the walrus operator

This state, except the variables_overwrite, is reset on every new assert
statement visited and used by the other visitors.
"""

def __init__(
Expand All @@ -652,6 +660,7 @@ def __init__(
else:
self.enable_assertion_pass_hook = False
self.source = source
self.variables_overwrite: Dict[str, str] = {}

def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
Expand All @@ -666,7 +675,7 @@ def run(self, mod: ast.Module) -> None:
if doc is not None and self.is_rewrite_disabled(doc):
return
pos = 0
lineno = 1
item = None
for item in mod.body:
if (
expect_docstring
Expand Down Expand Up @@ -937,6 +946,18 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
ast.copy_location(node, assert_)
return self.statements

def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
# This method handles the 'walrus operator' repr of the target
# name if it's a local variable or _should_repr_global_name()
# thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
target_id = name.target.id # type: ignore[attr-defined]
inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(name), ast.Str(target_id))
return name, self.explanation_param(expr)

def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
# Display the repr of the name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable.
Expand All @@ -963,6 +984,20 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
# Check if the left operand is a namedExpr and the value has already been visited
if (
isinstance(v, ast.Compare)
and isinstance(v.left, namedExpr)
and v.left.target.id
in [
ast_expr.id
for ast_expr in boolop.values[:i]
if hasattr(ast_expr, "id")
]
):
pytest_temp = self.variable()
self.variables_overwrite[v.left.target.id] = pytest_temp
v.left.target.id = pytest_temp
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
Expand Down Expand Up @@ -1038,6 +1073,9 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:

def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
comp.left.id = self.variables_overwrite[comp.left.id]
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
left_expl = f"({left_expl})"
Expand All @@ -1049,6 +1087,13 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
syms = []
results = [left_res]
for i, op, next_operand in it:
if (
isinstance(next_operand, namedExpr)
and isinstance(left_res, ast.Name)
and next_operand.target.id == left_res.id
):
next_operand.target.id = self.variable()
self.variables_overwrite[left_res.id] = next_operand.target.id
next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
next_expl = f"({next_expl})"
Expand All @@ -1072,6 +1117,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
res: ast.expr = ast.BoolOp(ast.And(), load_names)
else:
res = load_names[0]

return res, self.explanation_param(self.pop_format_context(expl_call))


Expand Down
171 changes: 171 additions & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,6 +1265,177 @@ def test_simple_failure():
result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"])


@pytest.mark.skipif(
sys.version_info < (3, 8), reason="walrus operator not available in py<38"
)
class TestIssue10743:
def test_assertion_walrus_operator(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
def my_func(before, after):
return before == after

def change_value(value):
return value.lower()

def test_walrus_conversion():
a = "Hello"
assert not my_func(a, a := change_value(a))
assert a == "hello"
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_dont_rewrite(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
'PYTEST_DONT_REWRITE'
def my_func(before, after):
return before == after

def change_value(value):
return value.lower()

def test_walrus_conversion_dont_rewrite():
a = "Hello"
assert not my_func(a, a := change_value(a))
assert a == "hello"
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_inline_walrus_operator(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
def my_func(before, after):
return before == after

def test_walrus_conversion_inline():
a = "Hello"
assert not my_func(a, a := a.lower())
assert a == "hello"
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_inline_walrus_operator_reverse(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
def my_func(before, after):
return before == after

def test_walrus_conversion_reverse():
a = "Hello"
assert my_func(a := a.lower(), a)
assert a == 'hello'
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_no_variable_name_conflict(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_conversion_no_conflict():
a = "Hello"
assert a == (b := a.lower())
"""
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"])

def test_assertion_walrus_operator_true_assertion_and_changes_variable_value(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_conversion_succeed():
a = "Hello"
assert a != (a := a.lower())
assert a == 'hello'
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_fail_assertion(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
def test_walrus_conversion_fails():
a = "Hello"
assert a == (a := a.lower())
"""
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"])

def test_assertion_walrus_operator_boolean_composite(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_operator_change_boolean_value():
a = True
assert a and True and ((a := False) is False) and (a is False) and ((a := None) is None)
assert a is None
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_compare_boolean_fails(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_operator_change_boolean_value():
a = True
assert not (a and ((a := False) is False))
"""
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert not (True and False is False)"])

def test_assertion_walrus_operator_boolean_none_fails(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_operator_change_boolean_value():
a = True
assert not (a and ((a := None) is None))
"""
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert not (True and None is None)"])

def test_assertion_walrus_operator_value_changes_cleared_after_each_test(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_operator_change_value():
a = True
assert (a := None) is None

def test_walrus_operator_not_override_value():
a = True
assert a is True
"""
)
result = pytester.runpytest()
assert result.ret == 0


@pytest.mark.skipif(
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
)
Expand Down