Skip to content

Commit

Permalink
Add ParamSpec node
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobtylerwalls committed Jun 25, 2023
1 parent b22b854 commit e977d97
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 3 deletions.
3 changes: 3 additions & 0 deletions astroid/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
NamedExpr,
NodeNG,
Nonlocal,
ParamSpec,
Pass,
Pattern,
Raise,
Expand Down Expand Up @@ -182,6 +183,7 @@
NamedExpr,
NodeNG,
Nonlocal,
ParamSpec,
Pass,
Pattern,
Raise,
Expand Down Expand Up @@ -275,6 +277,7 @@
"NamedExpr",
"NodeNG",
"Nonlocal",
"ParamSpec",
"Pass",
"Position",
"Raise",
Expand Down
4 changes: 4 additions & 0 deletions astroid/nodes/as_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ def visit_nonlocal(self, node) -> str:
"""return an astroid.Nonlocal node as string"""
return f"nonlocal {', '.join(node.names)}"

def visit_paramspec(self, node: nodes.ParamSpec) -> str:
"""return an astroid.ParamSpec node as string"""
return node.name

def visit_pass(self, node) -> str:
"""return an astroid.Pass node as string"""
return "pass"
Expand Down
50 changes: 48 additions & 2 deletions astroid/nodes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2695,6 +2695,52 @@ def _infer_name(self, frame, name):
return name


class ParamSpec(_base_nodes.AssignTypeNode):
"""Class representing a :class:`ast.ParamSpec` node.
>>> import astroid
>>> node = astroid.extract_node('type Alias[**P] = Callable[P, int]')
>>> node.type_params[0]
<ParamSpec l.1 at 0x7f23b2e4e198>
"""

def __init__(
self,
lineno: int | None = None,
col_offset: int | None = None,
parent: NodeNG | None = None,
*,
end_lineno: int | None = None,
end_col_offset: int | None = None,
) -> None:
self.name: str
super().__init__(
lineno=lineno,
col_offset=col_offset,
end_lineno=end_lineno,
end_col_offset=end_col_offset,
parent=parent,
)

def postinit(self, name: str) -> None:
self.name = name

assigned_stmts: ClassVar[
Callable[
[
ParamSpec,
AssignName,
InferenceContext | None,
None,
],
Generator[NodeNG, None, None],
]
]
"""Returns the assigned statement (non inferred) according to the assignment type.
See astroid/protocols.py for actual implementation.
"""


class Pass(_base_nodes.NoChildrenNode, _base_nodes.Statement):
"""Class representing an :class:`ast.Pass` node.
Expand Down Expand Up @@ -3329,7 +3375,7 @@ def __init__(
end_lineno: int | None = None,
end_col_offset: int | None = None,
) -> None:
self.type_params: list[TypeVar]
self.type_params: list[TypeVar, ParamSpec]
self.value: NodeNG
super().__init__(
lineno=lineno,
Expand All @@ -3342,7 +3388,7 @@ def __init__(
def postinit(
self,
*,
type_params: list[TypeVar],
type_params: list[TypeVar, ParamSpec],
value: NodeNG,
) -> None:
self.type_params = type_params
Expand Down
18 changes: 18 additions & 0 deletions astroid/rebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,12 @@ def visit(self, node: ast.Nonlocal, parent: NodeNG) -> nodes.Nonlocal:
def visit(self, node: ast.Constant, parent: NodeNG) -> nodes.Const:
...

if sys.version_info >= (3, 12):

@overload
def visit(self, node: ast.ParamSpec, parent: NodeNG) -> nodes.ParamSpec:
...

@overload
def visit(self, node: ast.Pass, parent: NodeNG) -> nodes.Pass:
...
Expand Down Expand Up @@ -1493,6 +1499,18 @@ def visit_constant(self, node: ast.Constant, parent: NodeNG) -> nodes.Const:
parent=parent,
)

def visit_paramspec(self, node: ast.ParamSpec, parent: NodeNG) -> nodes.ParamSpec:
"""Visit a ParamSpec node by returning a fresh instance of it."""
newnode = nodes.ParamSpec(
lineno=node.lineno,
col_offset=node.col_offset,
end_lineno=node.end_lineno,
end_col_offset=node.end_col_offset,
parent=parent,
)
newnode.postinit(node.name)
return newnode

def visit_pass(self, node: ast.Pass, parent: NodeNG) -> nodes.Pass:
"""Visit a Pass node by returning a fresh instance of it."""
return nodes.Pass(
Expand Down
3 changes: 3 additions & 0 deletions doc/api/astroid.nodes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ Nodes
astroid.nodes.Module
astroid.nodes.Name
astroid.nodes.Nonlocal
astroid.nodes.ParamSpec
astroid.nodes.Pass
astroid.nodes.Raise
astroid.nodes.Return
Expand Down Expand Up @@ -204,6 +205,8 @@ Nodes

.. autoclass:: astroid.nodes.Nonlocal

.. autoclass:: astroid.nodes.ParamSpec

.. autoclass:: astroid.nodes.Pass

.. autoclass:: astroid.nodes.Raise
Expand Down
10 changes: 9 additions & 1 deletion tests/test_type_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from astroid import extract_node
from astroid.const import PY312_PLUS
from astroid.nodes import Subscript, TypeAlias, TypeVar
from astroid.nodes import ParamSpec, Subscript, TypeAlias, TypeVar


@pytest.mark.skipif(not PY312_PLUS, reason="Requires Python 3.12 or higher")
Expand All @@ -23,6 +23,14 @@ def test_type_alias() -> None:
assert all(elt.name == "float" for elt in node.value.slice.elts)


@pytest.mark.skipif(not PY312_PLUS, reason="Requires Python 3.12 or higher")
def test_type_param_spec() -> None:
node = extract_node("type Alias[**P] = Callable[P, int]")
params = node.type_params[0]
assert isinstance(params, ParamSpec)
assert params.name == "P"


@pytest.mark.skipif(not PY312_PLUS, reason="Requires Python 3.12 or higher")
def test_type_param() -> None:
func_node = extract_node("def func[T]() -> T: ...")
Expand Down

0 comments on commit e977d97

Please sign in to comment.