Skip to content

Commit

Permalink
Improve typing of builtins brain (#2214)
Browse files Browse the repository at this point in the history
Resolves 14 mypy errors
  • Loading branch information
jacobtylerwalls authored Jun 20, 2023
1 parent 525c3b2 commit 842548d
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 29 deletions.
6 changes: 5 additions & 1 deletion astroid/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ def has_invalid_keywords(self) -> bool:
"""
return len(self.keyword_arguments) != len(self._unpacked_kwargs)

def _unpack_keywords(self, keywords, context: InferenceContext | None = None):
def _unpack_keywords(
self,
keywords: list[tuple[str | None, nodes.NodeNG]],
context: InferenceContext | None = None,
):
values = {}
context = context or InferenceContext()
context.extra_context = self.argument_context_map
Expand Down
86 changes: 58 additions & 28 deletions astroid/brain/brain_builtin_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from __future__ import annotations

import itertools
from collections.abc import Callable, Iterable, Iterator
from collections.abc import Callable, Iterator
from functools import partial
from typing import Any
from typing import Any, Type, Union, cast

from astroid import arguments, bases, helpers, inference_tip, nodes, objects, util
from astroid import arguments, helpers, inference_tip, nodes, objects, util
from astroid.builder import AstroidBuilder
from astroid.context import InferenceContext
from astroid.exceptions import (
Expand All @@ -23,7 +23,25 @@
)
from astroid.manager import AstroidManager
from astroid.nodes import scoped_nodes
from astroid.typing import InferenceResult, SuccessfulInferenceResult
from astroid.typing import (
ConstFactoryResult,
InferenceResult,
SuccessfulInferenceResult,
)

ContainerObjects = Union[
objects.FrozenSet,
objects.DictItems,
objects.DictKeys,
objects.DictValues,
]

BuiltContainers = Union[
Type[tuple],
Type[list],
Type[set],
Type[frozenset],
]

OBJECT_DUNDER_NEW = "object.__new__"

Expand Down Expand Up @@ -232,18 +250,19 @@ def _container_generic_inference(
return transformed


def _container_generic_transform( # pylint: disable=inconsistent-return-statements
def _container_generic_transform(
arg: SuccessfulInferenceResult,
context: InferenceContext | None,
klass: type[nodes.BaseContainer],
iterables: tuple[type[nodes.NodeNG] | type[bases.Proxy], ...],
build_elts: type[Iterable[Any]],
iterables: tuple[type[nodes.BaseContainer] | type[ContainerObjects], ...],
build_elts: BuiltContainers,
) -> nodes.BaseContainer | None:
if isinstance(arg, klass):
return arg
if isinstance(arg, iterables):
arg = cast(ContainerObjects, arg)
if all(isinstance(elt, nodes.Const) for elt in arg.elts):
elts = [elt.value for elt in arg.elts]
elts = [cast(nodes.Const, elt).value for elt in arg.elts]
else:
# TODO: Does not handle deduplication for sets.
elts = []
Expand All @@ -264,16 +283,16 @@ def _container_generic_transform( # pylint: disable=inconsistent-return-stateme
elif isinstance(arg, nodes.Const) and isinstance(arg.value, (str, bytes)):
elts = arg.value
else:
return
return None
return klass.from_elements(elts=build_elts(elts))


def _infer_builtin_container(
node: nodes.Call,
context: InferenceContext | None,
klass: type[nodes.BaseContainer],
iterables: tuple[type[nodes.NodeNG] | type[bases.Proxy], ...],
build_elts: type[Iterable[Any]],
iterables: tuple[type[nodes.NodeNG] | type[ContainerObjects], ...],
build_elts: BuiltContainers,
) -> nodes.BaseContainer:
transform_func = partial(
_container_generic_transform,
Expand Down Expand Up @@ -944,8 +963,8 @@ def _build_dict_with_elements(elements):


def _infer_copy_method(
node: nodes.Call, context: InferenceContext | None = None
) -> Iterator[nodes.NodeNG]:
node: nodes.Call, context: InferenceContext | None = None, **kwargs: Any
) -> Iterator[InferenceResult]:
assert isinstance(node.func, nodes.Attribute)
inferred_orig, inferred_copy = itertools.tee(node.func.expr.infer(context=context))
if all(
Expand Down Expand Up @@ -973,33 +992,44 @@ def _is_str_format_call(node: nodes.Call) -> bool:


def _infer_str_format_call(
node: nodes.Call, context: InferenceContext | None = None
) -> Iterator[nodes.Const | util.UninferableBase]:
node: nodes.Call, context: InferenceContext | None = None, **kwargs: Any
) -> Iterator[ConstFactoryResult | util.UninferableBase]:
"""Return a Const node based on the template and passed arguments."""
call = arguments.CallSite.from_call(node, context=context)
assert isinstance(node.func, (nodes.Attribute, nodes.AssignAttr, nodes.DelAttr))

value: nodes.Const
if isinstance(node.func.expr, nodes.Name):
value: nodes.Const | None = helpers.safe_infer(node.func.expr)
if value is None:
if not (inferred := helpers.safe_infer(node.func.expr)) or not isinstance(
inferred, nodes.Const
):
return iter([util.Uninferable])
else:
value = inferred
elif isinstance(node.func.expr, nodes.Const):
value = node.func.expr
else: # pragma: no cover
return iter([util.Uninferable])

format_template = value.value

# Get the positional arguments passed
inferred_positional = [
helpers.safe_infer(i, context) for i in call.positional_arguments
]
if not all(isinstance(i, nodes.Const) for i in inferred_positional):
return iter([util.Uninferable])
inferred_positional: list[nodes.Const] = []
for i in call.positional_arguments:
one_inferred = helpers.safe_infer(i, context)
if not isinstance(one_inferred, nodes.Const):
return iter([util.Uninferable])
inferred_positional.append(one_inferred)

pos_values: list[str] = [i.value for i in inferred_positional]

# Get the keyword arguments passed
inferred_keyword = {
k: helpers.safe_infer(v, context) for k, v in call.keyword_arguments.items()
}
if not all(isinstance(i, nodes.Const) for i in inferred_keyword.values()):
return iter([util.Uninferable])
inferred_keyword: dict[str, nodes.Const] = {}
for k, v in call.keyword_arguments.items():
one_inferred = helpers.safe_infer(v, context)
if not isinstance(one_inferred, nodes.Const):
return iter([util.Uninferable])
inferred_keyword[k] = one_inferred

keyword_values: dict[str, str] = {k: v.value for k, v in inferred_keyword.items()}

try:
Expand Down

0 comments on commit 842548d

Please sign in to comment.