Skip to content

Commit

Permalink
Add support for literal enums.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 436347608
  • Loading branch information
Solumin authored and rchen152 committed Mar 22, 2022
1 parent 243c95d commit 2c21609
Show file tree
Hide file tree
Showing 14 changed files with 225 additions and 60 deletions.
5 changes: 1 addition & 4 deletions pytype/abstract/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,10 +884,7 @@ def __hash__(self):

@property
def value(self):
if isinstance(self._instance, _instances.ConcreteValue):
return self._instance
# TODO(b/173742489): Remove this workaround once we support literal enums.
return None
return self._instance

def instantiate(self, node, container=None):
return self._instance.to_variable(node)
Expand Down
27 changes: 18 additions & 9 deletions pytype/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,19 +553,30 @@ def _create_module(self, ast):
members[name] = val
return abstract.Module(self.ctx, ast.name, members, ast)

def _get_literal_value(self, pyval):
def _get_literal_value(self, pyval, subst):
"""Extract and convert the value of a pytd.Literal."""
if isinstance(pyval, pytd.Constant):
# Literal enums are stored as Constants with the name set to the member
# name and the type set to a ClassType pointing to the enum cls.
cls = self.constant_to_value(pyval.type.cls)
_, name = pyval.name.rsplit(".", 1)
# Bad values should have been caught by visitors.VerifyEnumValues.
assert cls.is_enum, f"Non-enum type used in Literal: {cls.official_name}"
assert name in cls, f"Literal enum refers to non-existent member \"{pyval.name}\" of {cls.official_name}"
# The cls has already been converted, so don't try to convert the member.
return abstract_utils.get_atomic_value(cls.members[name])
if pyval == self.ctx.loader.lookup_builtin("builtins.True"):
return True
value = True
elif pyval == self.ctx.loader.lookup_builtin("builtins.False"):
return False
value = False
elif isinstance(pyval, str):
prefix, value = parser_constants.STRING_RE.match(pyval).groups()[:2]
value = value[1:-1] # remove quotation marks
if "b" in prefix:
value = str(value).encode("utf-8")
return value
else:
return pyval
value = pyval
return self.constant_to_value(value, subst, self.ctx.root_node)

def _constant_to_value(self, pyval, subst, get_node):
"""Create a BaseValue that represents a python constant.
Expand Down Expand Up @@ -798,8 +809,7 @@ def _constant_to_value(self, pyval, subst, get_node):
self._convert_cache[key] = instance
return self._convert_cache[key]
elif isinstance(cls, pytd.Literal):
return self.constant_to_value(
self._get_literal_value(cls.value), subst, self.ctx.root_node)
return self._get_literal_value(cls.value, subst)
else:
return self.constant_to_value(cls, subst, self.ctx.root_node)
elif (isinstance(pyval, pytd.GenericType) and
Expand Down Expand Up @@ -846,8 +856,7 @@ def _constant_to_value(self, pyval, subst, get_node):
template, parameters, subst)
return abstract_class(base_cls, type_parameters, self.ctx)
elif isinstance(pyval, pytd.Literal):
value = self.constant_to_value(
self._get_literal_value(pyval.value), subst, self.ctx.root_node)
value = self._get_literal_value(pyval.value, subst)
return abstract.LiteralClass(value, self.ctx)
elif isinstance(pyval, pytd.Annotated):
typ = self.constant_to_value(pyval.base_type, subst, self.ctx.root_node)
Expand Down
1 change: 1 addition & 0 deletions pytype/load_pytd.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def verify(self, mod_ast, *, mod_name=None):
name = mod_name or mod_ast.name
raise BadDependencyError(utils.message(e), name) from e
mod_ast.Visit(visitors.VerifyContainers())
mod_ast.Visit(visitors.VerifyLiterals())

@classmethod
def collect_dependencies(cls, mod_ast):
Expand Down
17 changes: 6 additions & 11 deletions pytype/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,19 +786,14 @@ def _match_instance_against_type(self, left, other_type, subst, view):
"""
if isinstance(other_type, abstract.LiteralClass):
other_value = other_type.value
if other_value and isinstance(left, abstract.ConcreteValue):
if isinstance(left, abstract.ConcreteValue):
return subst if left.pyval == other_value.pyval else None
elif other_value:
# `left` does not contain a concrete value. Literal overloads are
# always followed by at least one non-literal fallback, so we should
# fail here.
return None
elif isinstance(left, abstract.Instance) and left.cls.is_enum:
names_match = left.name == other_value.name
clses_match = left.cls == other_value.cls
return subst if names_match and clses_match else None
else:
# TODO(b/173742489): Remove this workaround once we can match against
# literal enums.
return self._match_type_against_type(
left, other_type.formal_type_parameters[abstract_utils.T], subst,
view)
return None
elif isinstance(other_type, typed_dict.TypedDictClass):
if not self._match_dict_against_typed_dict(left, other_type):
return None
Expand Down
13 changes: 6 additions & 7 deletions pytype/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ def value_instance_to_pytd_type(self, node, v, instance, seen, view):
return self.value_instance_to_pytd_type(
node, v.base_cls, instance, seen, view)
elif isinstance(v, abstract.LiteralClass):
if not v.value:
# TODO(b/173742489): Remove this workaround once we support literal
# enums.
return pytd.AnythingType()
if isinstance(v.value.pyval, (str, bytes)):
if isinstance(v.value, abstract.Instance) and v.value.cls.is_enum:
typ = pytd_utils.NamedTypeWithModule(
v.value.cls.official_name or v.value.cls.name, v.value.cls.module)
value = pytd.Constant(v.value.name, typ)
elif isinstance(v.value.pyval, (str, bytes)):
# Strings are stored as strings of their representations, prefix and
# quotes and all.
value = repr(v.value.pyval)
Expand All @@ -167,8 +167,7 @@ def value_instance_to_pytd_type(self, node, v, instance, seen, view):
else:
# Ints are stored as their literal values. Note that Literal[None] or a
# nested literal will never appear here, since we simplified it to None
# or unnested it, respectively, in typing_overlay. Literal[<enum>] does
# not appear here yet because it is unsupported.
# or unnested it, respectively, in typing_overlay.
assert isinstance(v.value.pyval, int), v.value.pyval
value = v.value.pyval
return pytd.Literal(value)
Expand Down
14 changes: 13 additions & 1 deletion pytype/overlays/enum_overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,11 @@ def _setup_interpreterclass(self, node, cls):
# creating the class -- pytype complains about recursive types.
member = abstract.Instance(cls, self.ctx)
member_var = member.to_variable(node)
# This makes literal enum equality checks easier. We could check the name
# attribute that we set below, but it's easier to compare these strings.
# Use the fully qualified name to be consistent with how literal enums
# are parsed from type stubs.
member.name = f"{cls.full_name}.{name}"
if "_value_" not in member.members:
if base_type:
args = function.Args(
Expand Down Expand Up @@ -641,8 +646,15 @@ def _setup_pytdclass(self, node, cls):
# Build instances directly, because you can't call instantiate() when
# creating the class -- pytype complains about recursive types.
member = abstract.Instance(cls, self.ctx)
# This makes literal enum equality checks easier. We could check the name
# attribute that we set below, but those aren't real strings.
# Use the fully qualified name to be consistent with how literal enums
# are parsed from type stubs.
member.name = f"{cls.full_name}.{pytd_val.name}"
member.members["name"] = self.ctx.convert.constant_to_var(
pyval=pytd.Constant(name="name", type=self._str_pytd), node=node)
pyval=pytd.Constant(
name="name", type=self._str_pytd, value=pytd_val.name),
node=node)
# Some type stubs may use the class type for enum member values, instead
# of the actual value type. Detect that and use Any.
if pytd_val.type.name == cls.pytd_cls.name:
Expand Down
6 changes: 4 additions & 2 deletions pytype/pyi/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,10 @@ def pytd_literal(parameters: List[Any]) -> pytd.Type:
if pytdgen.is_none(p):
literal_parameters.append(p)
elif isinstance(p, pytd.NamedType):
# TODO(b/173742489): support enums.
literal_parameters.append(pytd.AnythingType())
cls_name = p.name.rsplit(".", 1)[0]
literal_parameters.append(pytd.Literal(
pytd.Constant(name=p.name, type=pytd.NamedType(cls_name))
))
elif isinstance(p, types.Pyval):
literal_parameters.append(p.to_pytd_literal())
elif isinstance(p, pytd.Literal):
Expand Down
9 changes: 0 additions & 9 deletions pytype/pyi/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,21 +2404,12 @@ def test_none(self):
""", "x: None")

def test_enum(self):
# TODO(b/173742489): support enums.
self.check("""
import enum
from typing import Literal
x: Literal[Color.RED]
class Color(enum.Enum):
RED: str
""", """
import enum
from typing import Any
x: Any
class Color(enum.Enum):
RED: str
""")
Expand Down
11 changes: 7 additions & 4 deletions pytype/pytd/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,13 @@ def LeaveConstant(self, node):
def VisitConstant(self, node):
"""Convert a class-level or module-level constant to a string."""
if self.in_literal:
module, _, name = node.name.partition(".")
assert module == "builtins", module
assert name in ("True", "False"), name
return name
# This should be either True, False or an enum. For the booleans, strip
# off the module name. For enums, print the whole name.
if "builtins." in node.name:
_, _, name = node.name.partition(".")
return name
else:
return node.name
return f"{node.name}: {node.type}"

def EnterAlias(self, _):
Expand Down
2 changes: 1 addition & 1 deletion pytype/pytd/pytd.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ def ret(self):
@attr.s(auto_attribs=True, frozen=True, order=False, slots=True,
cache_hash=True)
class Literal(Node, Type):
value: Union[int, str, Type]
value: Union[int, str, Type, Constant]

@property
def name(self):
Expand Down
65 changes: 65 additions & 0 deletions pytype/pytd/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class SymbolLookupError(Exception):
pass


class LiteralValueError(Exception):
pass


# All public elements of pytd_visitors are aliased here so that we can maintain
# the conceptually simpler illusion of having a single visitors module.
ALL_NODE_NAMES = base_visitor.ALL_NODE_NAMES
Expand Down Expand Up @@ -1709,6 +1713,67 @@ def EnterClass(self, node):
t.type_param.full_name))


class VerifyLiterals(Visitor):
"""Visitor for verifying that Literal[object] contains an enum.
Other valid Literal types are checked by the parser, e.g. to make sure no
`float` values are used in Literals. Checking that an object in a Literal is
an enum member is more complex, so it gets its own visitor.
Because this visitor walks up the class hierarchy, it must be run after
ClassType pointers are filled in.
"""

def EnterLiteral(self, node):
value = node.value
if not isinstance(value, pytd.Constant):
# This Literal does not hold an object, no need to check further.
return

if value.name in ("builtins.True", "builtins.False"):
# When outputting `x: Literal[True]` from a source file, we write it as
# a Literal(Constant("builtins.True", type=ClassType("builtins.bool")))
# This is fine and does not need to be checked for enum-ness.
return

typ = value.type
if not isinstance(typ, pytd.ClassType):
# This happens sometimes, e.g. with stdlib type stubs that interact with
# C extensions. (tkinter.pyi, for example.) There's no point in trying to
# handle these case.
return
this_cls = typ.cls
assert this_cls, ("VerifyLiterals visitor must be run after ClassType "
"pointers are filled.")

# The fun part: Walk through each class in the MRO and figure out if it
# inherits from enum.Enum.
stack = [this_cls]
while stack:
cls = stack.pop()
if cls.name == "enum.Enum":
break
# We're only going to handle ClassType and Class here. The other types
# that may appear in ClassType.cls pointers or Class.bases lists are not
# common and may indicate that something is wrong.
if isinstance(cls, pytd.ClassType):
stack.extend(cls.cls.bases)
elif isinstance(cls, pytd.Class):
stack.extend(cls.bases)
else:
n = pytd_utils.Print(node)
msg = f"In {n}: {this_cls.name} is not an enum and cannot be used in typing.Literal"
raise LiteralValueError(msg)

# Second check: The member named in the Literal exists in the enum.
# We know at this point that value.name is "file.enum_class.member_name".
_, member_name = value.name.rsplit(".", 1)
if member_name not in this_cls:
n = pytd_utils.Print(node)
msg = f"In {n}: {value.name} is not a member of enum {this_cls.name}"
raise LiteralValueError(msg)


class ExpandCompatibleBuiltins(Visitor):
"""Ad-hoc inheritance.
Expand Down
49 changes: 48 additions & 1 deletion pytype/tests/test_typing1.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,18 +284,22 @@ def test_pyi_typing_extensions(self):
""", pythonpath=[d.path])
self.assertTypesMatchPytd(ty, "import foo")

# TODO(b/173742489): Include enums once we support looking up local enums.
def test_pyi_value(self):
with file_utils.Tempdir() as d:
d.create_file("foo.pyi", """
import enum
from typing import Literal
class Color(enum.Enum):
RED: str
def f1(x: Literal[True]) -> None: ...
def f2(x: Literal[2]) -> None: ...
def f3(x: Literal[None]) -> None: ...
def f4(x: Literal['hello']) -> None: ...
def f5(x: Literal[b'hello']) -> None: ...
def f6(x: Literal[u'hello']) -> None: ...
def f7(x: Literal[Color.RED]) -> None: ...
""")
self.Check("""
import foo
Expand All @@ -305,6 +309,7 @@ def f6(x: Literal[u'hello']) -> None: ...
foo.f4('hello')
foo.f5(b'hello')
foo.f6(u'hello')
foo.f7(foo.Color.RED)
""", pythonpath=[d.path])

def test_pyi_multiple(self):
Expand Down Expand Up @@ -412,6 +417,48 @@ def f1() -> int: ...
def f2() -> str: ...
""")

def test_illegal_literal_class(self):
# This should be a pyi-error, but checking happens during conversion.
with file_utils.Tempdir() as d:
d.create_file("foo.pyi", """
from typing import Literal
class NotEnum:
A: int
x: Literal[NotEnum.A]
""")
self.CheckWithErrors("""
import foo # pyi-error
""", pythonpath=[d.path])

def test_illegal_literal_class_indirect(self):
with file_utils.Tempdir() as d:
d.create_file("foo.pyi", """
class NotEnum:
A: int
""")
d.create_file("bar.pyi", """
from typing import Literal
import foo
y: Literal[foo.NotEnum.A]
""")
self.CheckWithErrors("""
import bar # pyi-error
""", pythonpath=[d.path])

def test_missing_enum_member(self):
# This should be a pyi-error, but checking happens during conversion.
with file_utils.Tempdir() as d:
d.create_file("foo.pyi", """
import enum
from typing import Literal
class M(enum.Enum):
A: int
x: Literal[M.B]
""")
self.CheckWithErrors("""
import foo # pyi-error
""", pythonpath=[d.path])


if __name__ == "__main__":
test_base.main()
Loading

0 comments on commit 2c21609

Please sign in to comment.