Skip to content

Commit

Permalink
Add 'default' field to pytd.TypeParameter.
Browse files Browse the repository at this point in the history
I did a few manual checks, and I believe this can be submitted without a
corresponding release without breaking anything. We might as well give it a try
and see how it goes.

For #1597.

PiperOrigin-RevId: 617071069
  • Loading branch information
rchen152 committed Mar 19, 2024
1 parent b899682 commit bf32914
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 10 deletions.
8 changes: 5 additions & 3 deletions pytype/pyi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,11 +468,13 @@ def add_type_variable(self, name, tvar):
raise _ParseError(f"{tvar.kind} name needs to be {tvar.name!r} "
f"(not {name!r})")
bound = tvar.bound
if isinstance(bound, str):
bound = pytd.NamedType(bound)
constraints = tuple(tvar.constraints) if tvar.constraints else ()
if isinstance(tvar.default, list):
default = tuple(tvar.default)
else:
default = tvar.default
self.type_params.append(pytd_type(
name=name, constraints=constraints, bound=bound))
name=name, constraints=constraints, bound=bound, default=default))

def add_import(self, from_package, import_list):
"""Add an import.
Expand Down
16 changes: 11 additions & 5 deletions pytype/pyi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re
import sys
import tokenize
from typing import Any, List, Optional, Tuple, cast
from typing import Any, List, Optional, Tuple, Union, cast

from pytype.ast import debug
from pytype.pyi import conditions
Expand Down Expand Up @@ -61,6 +61,7 @@ class _TypeVariable:
name: str
bound: Optional[pytd.Type]
constraints: List[pytd.Type]
default: Optional[Union[pytd.Type, List[pytd.Type]]]

@classmethod
def from_call(cls, kind: str, node: astlib.Call):
Expand All @@ -72,18 +73,20 @@ def from_call(cls, kind: str, node: astlib.Call):
if not types.Pyval.is_str(name):
raise ParseError(f"Bad arguments to {kind}")
bound = None
# 'bound' is the only keyword argument we currently use.
default = None
# TODO(rechen): We should enforce the PEP 484 guideline that
# len(constraints) != 1. However, this guideline is currently violated
# in typeshed (see https://github.com/python/typeshed/pull/806).
kws = {x.arg for x in node.keywords}
extra = kws - {"bound", "covariant", "contravariant"}
extra = kws - {"bound", "covariant", "contravariant", "default"}
if extra:
raise ParseError(f"Unrecognized keyword(s): {', '.join(extra)}")
for kw in node.keywords:
if kw.arg == "bound":
bound = kw.value
return cls(kind, name.value, bound, constraints)
elif kw.arg == "default":
default = kw.value
return cls(kind, name.value, bound, constraints, default)

#------------------------------------------------------
# Main tree visitor and generator code
Expand Down Expand Up @@ -674,6 +677,8 @@ def _convert_typevar_args(self, node: astlib.Call):
for kw in node.keywords:
if kw.arg == "bound":
kw.value = self.annotation_visitor.visit(kw.value)
elif kw.arg == "default":
kw.value = self.annotation_visitor.visit(kw.value)

def _convert_typed_dict_args(self, node: astlib.Call):
for fields in node.args[1:]:
Expand All @@ -682,7 +687,8 @@ def _convert_typed_dict_args(self, node: astlib.Call):
def enter_Call(self, node):
node.func = self.annotation_visitor.visit(node.func)
func = node.func.name or ""
if self.defs.matches_type(func, ("typing.TypeVar", "typing.ParamSpec")):
if self.defs.matches_type(func, ("typing.TypeVar", "typing.ParamSpec",
"typing.TypeVarTuple")):
self._convert_typevar_args(node)
elif self.defs.matches_type(func, "typing.NamedTuple"):
self._convert_typing_namedtuple_args(node)
Expand Down
25 changes: 25 additions & 0 deletions pytype/pyi/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3526,5 +3526,30 @@ def f(x: Tuple[Any, ...]) -> Any: ...
""")


class TypeParameterDefaultTest(parser_test_base.ParserTestBase):

def test_typevar(self):
self.check("""
from typing_extensions import TypeVar
T = TypeVar('T', default=int)
""")

def test_paramspec(self):
self.check("""
from typing_extensions import ParamSpec
P = ParamSpec('P', default=[str, int])
""")

def test_typevartuple(self):
self.check("""
from typing_extensions import TypeVarTuple, Unpack
Ts = TypeVarTuple('Ts', default=Unpack[tuple[str, int]])
""", """
from typing_extensions import TypeVarTuple, TypeVarTuple as Ts, Unpack
""")


if __name__ == "__main__":
unittest.main()
5 changes: 5 additions & 0 deletions pytype/pytd/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ def _FormatTypeParams(self, type_params):
args += [self.Print(c) for c in t.constraints]
if t.bound:
args.append(f"bound={self.Print(t.bound)}")
if isinstance(t.default, tuple):
args.append(
f"default=[{', '.join(self.Print(d) for d in t.default)}]")
elif t.default:
args.append(f"default={self.Print(t.default)}")
if isinstance(t, pytd.ParamSpec):
typename = self._LookupTypingMember("ParamSpec")
else:
Expand Down
1 change: 1 addition & 0 deletions pytype/pytd/pytd.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def f(x: T) -> T
name: str
constraints: Tuple[TypeU, ...] = ()
bound: Optional[TypeU] = None
default: Optional[Union[TypeU, Tuple[TypeU, ...]]] = None
scope: Optional[str] = None

def __lt__(self, other):
Expand Down
3 changes: 1 addition & 2 deletions pytype/pytd/visitors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,8 +552,7 @@ class A(Dict[T, T], Generic[T]): pass
""")
a = ast.Lookup("A")
self.assertEqual(
(pytd.TemplateItem(pytd.TypeParameter("T", (), None, "A")),),
a.template)
(pytd.TemplateItem(pytd.TypeParameter("T", scope="A")),), a.template)

def test_adjust_type_parameters_with_duplicates_in_generic(self):
src = textwrap.dedent("""
Expand Down

0 comments on commit bf32914

Please sign in to comment.