From 6c1d8671ce6eaf2c955fa986cbad51d6e6726d5d Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 24 Jun 2024 20:57:29 +0100 Subject: [PATCH] Fix ParamSpec inference against TypeVarTuple (#17431) Fixes https://github.com/python/mypy/issues/17278 Fixes https://github.com/python/mypy/issues/17127 --- mypy/constraints.py | 6 ++- mypy/expandtype.py | 14 ++++++- mypy/semanal_typeargs.py | 12 +----- mypy/types.py | 13 +++++- test-data/unit/check-typevar-tuple.test | 53 +++++++++++++++++++++++++ 5 files changed, 85 insertions(+), 13 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 316f481ac870..49a2aea8fa05 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -1071,7 +1071,11 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # (with literal '...'). if not template.is_ellipsis_args: unpack_present = find_unpack_in_list(template.arg_types) - if unpack_present is not None: + # When both ParamSpec and TypeVarTuple are present, things become messy + # quickly. For now, we only allow ParamSpec to "capture" TypeVarTuple, + # but not vice versa. + # TODO: infer more from prefixes when possible. + if unpack_present is not None and not cactual.param_spec(): # We need to re-normalize args to the form they appear in tuples, # for callables we always pack the suffix inside another tuple. unpack = template.arg_types[unpack_present] diff --git a/mypy/expandtype.py b/mypy/expandtype.py index bff23c53defd..5c4d6af9458e 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -270,6 +270,13 @@ def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: repl = self.variables.get(t.id, t) if isinstance(repl, TypeVarTupleType): return repl + elif isinstance(repl, ProperType) and isinstance(repl, (AnyType, UninhabitedType)): + # Some failed inference scenarios will try to set all type variables to Never. + # Instead of being picky and require all the callers to wrap them, + # do this here instead. + # Note: most cases when this happens are handled in expand unpack below, but + # in rare cases (e.g. ParamSpec containing Unpack star args) it may be skipped. + return t.tuple_fallback.copy_modified(args=[repl]) raise NotImplementedError def visit_unpack_type(self, t: UnpackType) -> Type: @@ -348,7 +355,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType: # the replacement is ignored. if isinstance(repl, Parameters): # We need to expand both the types in the prefix and the ParamSpec itself - return t.copy_modified( + expanded = t.copy_modified( arg_types=self.expand_types(t.arg_types[:-2]) + repl.arg_types, arg_kinds=t.arg_kinds[:-2] + repl.arg_kinds, arg_names=t.arg_names[:-2] + repl.arg_names, @@ -358,6 +365,11 @@ def visit_callable_type(self, t: CallableType) -> CallableType: imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds), variables=[*repl.variables, *t.variables], ) + var_arg = expanded.var_arg() + if var_arg is not None and isinstance(var_arg.typ, UnpackType): + # Sometimes we get new unpacks after expanding ParamSpec. + expanded.normalize_trivial_unpack() + return expanded elif isinstance(repl, ParamSpecType): # We're substituting one ParamSpec for another; this can mean that the prefix # changes, e.g. substitute Concatenate[int, P] in place of Q. diff --git a/mypy/semanal_typeargs.py b/mypy/semanal_typeargs.py index 02cb1b1f6128..dbf5136afa1b 100644 --- a/mypy/semanal_typeargs.py +++ b/mypy/semanal_typeargs.py @@ -15,7 +15,7 @@ from mypy.message_registry import INVALID_PARAM_SPEC_LOCATION, INVALID_PARAM_SPEC_LOCATION_NOTE from mypy.messages import format_type from mypy.mixedtraverser import MixedTraverserVisitor -from mypy.nodes import ARG_STAR, Block, ClassDef, Context, FakeInfo, FuncItem, MypyFile +from mypy.nodes import Block, ClassDef, Context, FakeInfo, FuncItem, MypyFile from mypy.options import Options from mypy.scope import Scope from mypy.subtypes import is_same_type, is_subtype @@ -104,15 +104,7 @@ def visit_tuple_type(self, t: TupleType) -> None: def visit_callable_type(self, t: CallableType) -> None: super().visit_callable_type(t) - # Normalize trivial unpack in var args as *args: *tuple[X, ...] -> *args: X - if t.is_var_arg: - star_index = t.arg_kinds.index(ARG_STAR) - star_type = t.arg_types[star_index] - if isinstance(star_type, UnpackType): - p_type = get_proper_type(star_type.type) - if isinstance(p_type, Instance): - assert p_type.type.fullname == "builtins.tuple" - t.arg_types[star_index] = p_type.args[0] + t.normalize_trivial_unpack() def visit_instance(self, t: Instance) -> None: super().visit_instance(t) diff --git a/mypy/types.py b/mypy/types.py index 3f764a5cc49e..52f8a8d63f09 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2084,6 +2084,17 @@ def param_spec(self) -> ParamSpecType | None: prefix = Parameters(self.arg_types[:-2], self.arg_kinds[:-2], self.arg_names[:-2]) return arg_type.copy_modified(flavor=ParamSpecFlavor.BARE, prefix=prefix) + def normalize_trivial_unpack(self) -> None: + # Normalize trivial unpack in var args as *args: *tuple[X, ...] -> *args: X in place. + if self.is_var_arg: + star_index = self.arg_kinds.index(ARG_STAR) + star_type = self.arg_types[star_index] + if isinstance(star_type, UnpackType): + p_type = get_proper_type(star_type.type) + if isinstance(p_type, Instance): + assert p_type.type.fullname == "builtins.tuple" + self.arg_types[star_index] = p_type.args[0] + def with_unpacked_kwargs(self) -> NormalizedCallableType: if not self.unpack_kwargs: return cast(NormalizedCallableType, self) @@ -2113,7 +2124,7 @@ def with_normalized_var_args(self) -> Self: if not isinstance(unpacked, TupleType): # Note that we don't normalize *args: *tuple[X, ...] -> *args: X, # this should be done once in semanal_typeargs.py for user-defined types, - # and we ourselves should never construct such type. + # and we ourselves rarely construct such type. return self unpack_index = find_unpack_in_list(unpacked.items) if unpack_index == 0 and len(unpacked.items) > 1: diff --git a/test-data/unit/check-typevar-tuple.test b/test-data/unit/check-typevar-tuple.test index 49298114e069..ea692244597c 100644 --- a/test-data/unit/check-typevar-tuple.test +++ b/test-data/unit/check-typevar-tuple.test @@ -2407,3 +2407,56 @@ reveal_type(x) # N: Revealed type is "__main__.C[builtins.str, builtins.int]" reveal_type(C(f)) # N: Revealed type is "__main__.C[builtins.str, builtins.int, builtins.int, builtins.int, builtins.int]" C[()] # E: At least 1 type argument(s) expected, none given [builtins fixtures/tuple.pyi] + +[case testTypeVarTupleAgainstParamSpecActualSuccess] +from typing import Generic, TypeVar, TypeVarTuple, Unpack, Callable, Tuple, List +from typing_extensions import ParamSpec + +R = TypeVar("R") +P = ParamSpec("P") + +class CM(Generic[R]): ... +def cm(fn: Callable[P, R]) -> Callable[P, CM[R]]: ... + +Ts = TypeVarTuple("Ts") +@cm +def test(*args: Unpack[Ts]) -> Tuple[Unpack[Ts]]: ... + +reveal_type(test) # N: Revealed type is "def [Ts] (*args: Unpack[Ts`-1]) -> __main__.CM[Tuple[Unpack[Ts`-1]]]" +reveal_type(test(1, 2, 3)) # N: Revealed type is "__main__.CM[Tuple[Literal[1]?, Literal[2]?, Literal[3]?]]" +[builtins fixtures/tuple.pyi] + +[case testTypeVarTupleAgainstParamSpecActualFailedNoCrash] +from typing import Generic, TypeVar, TypeVarTuple, Unpack, Callable, Tuple, List +from typing_extensions import ParamSpec + +R = TypeVar("R") +P = ParamSpec("P") + +class CM(Generic[R]): ... +def cm(fn: Callable[P, List[R]]) -> Callable[P, CM[R]]: ... + +Ts = TypeVarTuple("Ts") +@cm # E: Argument 1 to "cm" has incompatible type "Callable[[VarArg(Unpack[Ts])], Tuple[Unpack[Ts]]]"; expected "Callable[[VarArg(Never)], List[Never]]" +def test(*args: Unpack[Ts]) -> Tuple[Unpack[Ts]]: ... + +reveal_type(test) # N: Revealed type is "def (*args: Never) -> __main__.CM[Never]" +[builtins fixtures/tuple.pyi] + +[case testTypeVarTupleAgainstParamSpecActualPrefix] +from typing import Generic, TypeVar, TypeVarTuple, Unpack, Callable, Tuple, List +from typing_extensions import ParamSpec, Concatenate + +R = TypeVar("R") +P = ParamSpec("P") +T = TypeVar("T") + +class CM(Generic[R]): ... +def cm(fn: Callable[Concatenate[T, P], R]) -> Callable[Concatenate[List[T], P], CM[R]]: ... + +Ts = TypeVarTuple("Ts") +@cm +def test(x: T, *args: Unpack[Ts]) -> Tuple[T, Unpack[Ts]]: ... + +reveal_type(test) # N: Revealed type is "def [T, Ts] (builtins.list[T`2], *args: Unpack[Ts`-2]) -> __main__.CM[Tuple[T`2, Unpack[Ts`-2]]]" +[builtins fixtures/tuple.pyi]