Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further improve typing of builtins brain #2225

Merged
2 changes: 1 addition & 1 deletion astroid/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _unpack_keywords(
keywords: list[tuple[str | None, nodes.NodeNG]],
context: InferenceContext | None = None,
):
values = {}
values: dict[str | None, InferenceResult] = {}
context = context or InferenceContext()
context.extra_context = self.argument_context_map
for name, value in keywords:
Expand Down
52 changes: 38 additions & 14 deletions astroid/brain/brain_builtin_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from __future__ import annotations

import itertools
from collections.abc import Callable, Iterator
from collections.abc import Callable, Iterable
from functools import partial
from typing import Any, Type, Union, cast
from typing import TYPE_CHECKING, Any, Iterator, NoReturn, Type, Union, cast

from astroid import arguments, helpers, inference_tip, nodes, objects, util
from astroid.builder import AstroidBuilder
Expand All @@ -29,6 +29,9 @@
SuccessfulInferenceResult,
)

if TYPE_CHECKING:
from astroid.bases import Instance

ContainerObjects = Union[
objects.FrozenSet,
objects.DictItems,
Expand All @@ -43,6 +46,13 @@
Type[frozenset],
]

CopyResult = Union[
nodes.Dict,
nodes.List,
nodes.Set,
objects.FrozenSet,
]

OBJECT_DUNDER_NEW = "object.__new__"

STR_CLASS = """
Expand Down Expand Up @@ -127,6 +137,10 @@ def ljust(self, width, fillchar=None):
"""


def _use_default() -> NoReturn: # pragma: no cover
raise UseInferenceDefault()


def _extend_string_class(class_node, code, rvalue):
"""Function to extend builtin str/unicode class."""
code = code.format(rvalue=rvalue)
Expand Down Expand Up @@ -193,7 +207,9 @@ def register_builtin_transform(transform, builtin_name) -> None:
an optional context.
"""

def _transform_wrapper(node, context: InferenceContext | None = None):
def _transform_wrapper(
node: nodes.Call, context: InferenceContext | None = None, **kwargs: Any
) -> Iterator:
result = transform(node, context=context)
if result:
if not result.parent:
Expand Down Expand Up @@ -257,10 +273,12 @@ def _container_generic_transform(
iterables: tuple[type[nodes.BaseContainer] | type[ContainerObjects], ...],
build_elts: BuiltContainers,
) -> nodes.BaseContainer | None:
elts: Iterable | str | bytes

if isinstance(arg, klass):
return arg
if isinstance(arg, iterables):
arg = cast(ContainerObjects, arg)
arg = cast(Union[nodes.BaseContainer, ContainerObjects], arg)
if all(isinstance(elt, nodes.Const) for elt in arg.elts):
elts = [cast(nodes.Const, elt).value for elt in arg.elts]
else:
Expand All @@ -277,9 +295,10 @@ def _container_generic_transform(
elts.append(evaluated_object)
elif isinstance(arg, nodes.Dict):
# Dicts need to have consts as strings already.
if not all(isinstance(elt[0], nodes.Const) for elt in arg.items):
raise UseInferenceDefault()
elts = [item[0].value for item in arg.items]
elts = [
item[0].value if isinstance(item[0], nodes.Const) else _use_default()
for item in arg.items
]
elif isinstance(arg, nodes.Const) and isinstance(arg.value, (str, bytes)):
elts = arg.value
else:
Expand Down Expand Up @@ -399,6 +418,7 @@ def infer_dict(node: nodes.Call, context: InferenceContext | None = None) -> nod
args = call.positional_arguments
kwargs = list(call.keyword_arguments.items())

items: list[tuple[InferenceResult, InferenceResult]]
if not args and not kwargs:
# dict()
return nodes.Dict(
Expand Down Expand Up @@ -695,7 +715,9 @@ def infer_slice(node, context: InferenceContext | None = None):
return slice_node


def _infer_object__new__decorator(node, context: InferenceContext | None = None):
def _infer_object__new__decorator(
node: nodes.ClassDef, context: InferenceContext | None = None, **kwargs: Any
) -> Iterator[Instance]:
# Instantiate class immediately
# since that's what @object.__new__ does
return iter((node.instantiate_class(),))
Expand Down Expand Up @@ -944,10 +966,10 @@ def _build_dict_with_elements(elements):
if isinstance(inferred_values, nodes.Const) and isinstance(
inferred_values.value, (str, bytes)
):
elements = [
elements_with_value = [
(nodes.Const(element), default) for element in inferred_values.value
]
return _build_dict_with_elements(elements)
return _build_dict_with_elements(elements_with_value)
if isinstance(inferred_values, nodes.Dict):
keys = inferred_values.itered()
for key in keys:
Expand All @@ -964,7 +986,7 @@ def _build_dict_with_elements(elements):

def _infer_copy_method(
node: nodes.Call, context: InferenceContext | None = None, **kwargs: Any
) -> Iterator[InferenceResult]:
) -> Iterator[CopyResult]:
assert isinstance(node.func, nodes.Attribute)
inferred_orig, inferred_copy = itertools.tee(node.func.expr.infer(context=context))
if all(
Expand All @@ -973,9 +995,9 @@ def _infer_copy_method(
)
for inferred_node in inferred_orig
):
return inferred_copy
return cast(Iterator[CopyResult], inferred_copy)

raise UseInferenceDefault()
raise UseInferenceDefault


def _is_str_format_call(node: nodes.Call) -> bool:
Expand Down Expand Up @@ -1081,5 +1103,7 @@ def _infer_str_format_call(
)

AstroidManager().register_transform(
nodes.Call, inference_tip(_infer_str_format_call), _is_str_format_call
nodes.Call,
inference_tip(_infer_str_format_call),
_is_str_format_call,
)
12 changes: 6 additions & 6 deletions astroid/nodes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1865,9 +1865,7 @@ def __init__(
parent=parent,
)

def postinit(
self, items: list[tuple[SuccessfulInferenceResult, SuccessfulInferenceResult]]
) -> None:
def postinit(self, items: list[tuple[InferenceResult, InferenceResult]]) -> None:
"""Do some setup after initialisation.

:param items: The key-value pairs contained in the dictionary.
Expand Down Expand Up @@ -3911,11 +3909,13 @@ class EvaluatedObject(NodeNG):
_astroid_fields = ("original",)
_other_fields = ("value",)

def __init__(self, original: NodeNG, value: NodeNG | util.UninferableBase) -> None:
self.original: NodeNG = original
def __init__(
self, original: SuccessfulInferenceResult, value: InferenceResult
) -> None:
self.original: SuccessfulInferenceResult = original
"""The original node that has already been evaluated"""

self.value: NodeNG | util.UninferableBase = value
self.value: InferenceResult = value
"""The inferred value"""

super().__init__(
Expand Down
4 changes: 2 additions & 2 deletions astroid/rebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from astroid.manager import AstroidManager
from astroid.nodes import NodeNG
from astroid.nodes.utils import Position
from astroid.typing import SuccessfulInferenceResult
from astroid.typing import InferenceResult

REDIRECT: Final[dict[str, str]] = {
"arguments": "Arguments",
Expand Down Expand Up @@ -994,7 +994,7 @@ def visit_dict(self, node: ast.Dict, parent: NodeNG) -> nodes.Dict:
end_col_offset=node.end_col_offset,
parent=parent,
)
items: list[tuple[SuccessfulInferenceResult, SuccessfulInferenceResult]] = list(
items: list[tuple[InferenceResult, InferenceResult]] = list(
self._visit_dict_items(node, parent, newnode)
)
newnode.postinit(items)
Expand Down
4 changes: 3 additions & 1 deletion astroid/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
)

if TYPE_CHECKING:
from collections.abc import Iterator

from astroid import bases, exceptions, nodes, transforms, util
from astroid.context import InferenceContext
from astroid.interpreter._import import spec
Expand Down Expand Up @@ -84,7 +86,7 @@ def __call__(
node: _SuccessfulInferenceResultT_contra,
context: InferenceContext | None = None,
**kwargs: Any,
) -> Generator[InferenceResult, None, None]:
) -> Iterator[InferenceResult]:
... # pragma: no cover


Expand Down
5 changes: 4 additions & 1 deletion tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1904,7 +1904,10 @@ def test_str_repr_no_warnings(node):

if "int" in param_type.annotation:
args[name] = random.randint(0, 50)
elif "NodeNG" in param_type.annotation:
elif (
"NodeNG" in param_type.annotation
or "SuccessfulInferenceResult" in param_type.annotation
):
args[name] = nodes.Unknown()
elif "str" in param_type.annotation:
args[name] = ""
Expand Down
Loading