Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve consistency of JoinedStr inference #2622

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
9 changes: 9 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ Release date: TBA
Closes #2521
Closes #2523

* Improve consistency of ``JoinedStr`` inference by not raising ``InferenceError`` and
returning either ``Uninferable`` or a fully resolved ``Const``.

* Introduces flag ``JoinedStr.FAIL_ON_UNINFERABLE`` which defaults to ``True``, but can be set to ```False``,
thus allowing return of partially inferred strings, for example "a/{MISSING_VALUE}/b" when
inferring ``f"a/{missing}/b"`` with ``missing`` being uninferable

Closes #2621

* Fix crash when typing._alias() call is missing arguments.

Closes #2513
Expand Down
52 changes: 39 additions & 13 deletions astroid/nodes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4715,7 +4715,7 @@ def _infer(
continue


MISSING_VALUE = "{MISSING_VALUE}"
UNINFERABLE_VALUE = "{Uninferable}"


class JoinedStr(NodeNG):
Expand Down Expand Up @@ -4781,33 +4781,59 @@ def get_children(self):
def _infer(
self, context: InferenceContext | None = None, **kwargs: Any
) -> Generator[InferenceResult, None, InferenceErrorInfo | None]:
yield from self._infer_from_values(self.values, context)
if self.values:
yield from self._infer_with_values(context)
else:
yield Const("")

def _infer_with_values(
self, context: InferenceContext | None = None, **kwargs: Any
) -> Generator[InferenceResult, None, InferenceErrorInfo | None]:
uninferable_already_generated = False
for inferred in self._infer_from_values(self.values, context):
failed = (
inferred is util.Uninferable
or isinstance(inferred, Const)
and UNINFERABLE_VALUE in inferred.value
)
if failed:
if not uninferable_already_generated:
uninferable_already_generated = True
yield util.Uninferable
continue
yield inferred

@classmethod
def _infer_from_values(
cls, nodes: list[NodeNG], context: InferenceContext | None = None, **kwargs: Any
) -> Generator[InferenceResult, None, InferenceErrorInfo | None]:
if not nodes:
yield
return
if len(nodes) == 1:
yield from nodes[0]._infer(context, **kwargs)
for node in cls._safe_infer_from_node(nodes[0], context, **kwargs):
if isinstance(node, Const):
yield node
continue
yield Const(UNINFERABLE_VALUE)
return
uninferable_already_generated = False
for prefix in nodes[0]._infer(context, **kwargs):
for prefix in cls._safe_infer_from_node(nodes[0], context, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooh, good spot. If we isolate this change I bet we could get quick consensus to merge it 👍

for suffix in cls._infer_from_values(nodes[1:], context, **kwargs):
result = ""
for node in (prefix, suffix):
if isinstance(node, Const):
result += str(node.value)
continue
result += MISSING_VALUE
if MISSING_VALUE in result:
if not uninferable_already_generated:
uninferable_already_generated = True
yield util.Uninferable
else:
yield Const(result)
result += UNINFERABLE_VALUE
yield Const(result)

@classmethod
def _safe_infer_from_node(
cls, node: NodeNG, context: InferenceContext | None = None, **kwargs: Any
) -> Generator[InferenceResult, None, InferenceErrorInfo | None]:
try:
yield from node._infer(context, **kwargs)
except InferenceError:
yield util.Uninferable


class NamedExpr(_base_nodes.AssignTypeNode):
Expand Down
14 changes: 11 additions & 3 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7402,7 +7402,12 @@ class Cls:
""",
"<__main__.Cls",
),
("s1 = f'{5}' #@", "5"),
(
"s1 = f'{5}' #@",
"5",
),
("s1 = f'{missing}'", None),
("s1 = f'a/{missing}/b'", None),
],
)
def test_joined_str_returns_string(source, expected) -> None:
Expand All @@ -7413,5 +7418,8 @@ def test_joined_str_returns_string(source, expected) -> None:
assert target
inferred = list(target.inferred())
assert len(inferred) == 1
assert isinstance(inferred[0], Const)
inferred[0].value.startswith(expected)
if expected:
assert isinstance(inferred[0], Const)
inferred[0].value.startswith(expected)
else:
assert inferred[0] is Uninferable
Loading