Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

11028 - Fix warlus operator behavior when called by a function #11041

Merged
merged 5 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/11028.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed bug in assertion rewriting where a variable assigned with the walrus operator could not be used later in a function call.
25 changes: 22 additions & 3 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,9 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
]
):
pytest_temp = self.variable()
self.variables_overwrite[v.left.target.id] = pytest_temp
self.variables_overwrite[
v.left.target.id
] = v.left # type:ignore[assignment]
v.left.target.id = pytest_temp
self.push_format_context()
res, expl = self.visit(v)
Expand Down Expand Up @@ -1037,10 +1039,19 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
new_args = []
new_kwargs = []
for arg in call.args:
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite:
arg = self.variables_overwrite[arg.id] # type:ignore[assignment]
res, expl = self.visit(arg)
arg_expls.append(expl)
new_args.append(res)
for keyword in call.keywords:
if (
isinstance(keyword.value, ast.Name)
and keyword.value.id in self.variables_overwrite
):
keyword.value = self.variables_overwrite[
keyword.value.id
] # type:ignore[assignment]
res, expl = self.visit(keyword.value)
new_kwargs.append(ast.keyword(keyword.arg, res))
if keyword.arg:
Expand Down Expand Up @@ -1075,7 +1086,13 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
comp.left.id = self.variables_overwrite[comp.left.id]
comp.left = self.variables_overwrite[
comp.left.id
] # type:ignore[assignment]
if isinstance(comp.left, namedExpr):
self.variables_overwrite[
comp.left.target.id
] = comp.left # type:ignore[assignment]
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
left_expl = f"({left_expl})"
Expand All @@ -1093,7 +1110,9 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
and next_operand.target.id == left_res.id
):
next_operand.target.id = self.variable()
self.variables_overwrite[left_res.id] = next_operand.target.id
self.variables_overwrite[
left_res.id
] = next_operand # type:ignore[assignment]
next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
next_expl = f"({next_expl})"
Expand Down
90 changes: 90 additions & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,96 @@ def test_walrus_operator_not_override_value():
assert result.ret == 0


@pytest.mark.skipif(
sys.version_info < (3, 8), reason="walrus operator not available in py<38"
)
class TestIssue11028:
def test_assertion_walrus_operator_in_operand(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
def test_in_string():
assert (obj := "foo") in obj
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_in_operand_json_dumps(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
import json

def test_json_encoder():
assert (obj := "foo") in json.dumps(obj)
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_equals_operand_function(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def f(a):
return a

def test_call_other_function_arg():
assert (obj := "foo") == f(obj)
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_equals_operand_function_keyword_arg(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def f(a='test'):
return a

def test_call_other_function_k_arg():
assert (obj := "foo") == f(a=obj)
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_equals_operand_function_arg_as_function(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def f(a='test'):
return a

def test_function_of_function():
assert (obj := "foo") == f(f(obj))
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_gt_operand_function(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def add_one(a):
return a + 1

def test_gt():
assert (obj := 4) > add_one(obj)
"""
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"])


@pytest.mark.skipif(
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
)
Expand Down