Skip to content

Commit

Permalink
rewrite: load and store instance attributes.
Browse files Browse the repository at this point in the history
* Splits abstract.Instance into FrozenInstance and MutableInstance, so that we
  can easily distinguish between canonical instances created by
  BaseClass.instantiate() that should ignore any attribute setting done outside
  of __new__ and __init__, and instances created in the course of bytecode
  analysis that should respect all modifications.
* Implements opcodes for loading and storing attributes.

PiperOrigin-RevId: 615684582
  • Loading branch information
rchen152 committed Mar 14, 2024
1 parent e9a60b7 commit 2437282
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 24 deletions.
1 change: 1 addition & 0 deletions pytype/rewrite/abstract/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ py_test(
DEPS
.base
.classes
.functions
)

py_library(
Expand Down
3 changes: 3 additions & 0 deletions pytype/rewrite/abstract/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
NULL = _base.NULL

BaseClass = _classes.BaseClass
FrozenInstance = _classes.FrozenInstance
InterpreterClass = _classes.InterpreterClass
MutableInstance = _classes.MutableInstance
BUILD_CLASS = _classes.BUILD_CLASS

Args = _functions.Args
BoundFunction = _functions.BoundFunction
InterpreterFunction = _functions.InterpreterFunction

get_atomic_constant = _utils.get_atomic_constant
9 changes: 8 additions & 1 deletion pytype/rewrite/abstract/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Base abstract representation of Python values."""

from typing import Generic, Set, TypeVar
from typing import Generic, Optional, Set, TypeVar

from pytype.rewrite.flow import variables
from typing_extensions import Self
Expand All @@ -13,6 +13,13 @@ class BaseValue:
def to_variable(self: Self) -> variables.Variable[Self]:
return variables.Variable.from_value(self)

def get_attribute(self, name: str) -> Optional['BaseValue']:
del name # unused
return None

def set_attribute(self, name: str, value: 'BaseValue') -> None:
del name, value # unused


class PythonConstant(BaseValue, Generic[_T]):

Expand Down
79 changes: 71 additions & 8 deletions pytype/rewrite/abstract/classes.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
"""Abstract representations of classes."""

from typing import List, Mapping, Optional, Sequence
import dataclasses

from typing import Dict, List, Optional, Protocol, Sequence

from pytype.rewrite.abstract import base
from pytype.rewrite.abstract import functions as functions_lib


class _HasMembers(Protocol):

members: Dict[str, base.BaseValue]


@dataclasses.dataclass
class ClassCallReturn:

instance: 'MutableInstance'

def get_return_value(self):
return self.instance


class BaseClass(base.BaseValue):
"""Base representation of a class."""

def __init__(self, name: str, members: Mapping[str, base.BaseValue]):
def __init__(self, name: str, members: Dict[str, base.BaseValue]):
self.name = name
self.members = members

Expand All @@ -30,7 +46,7 @@ def __repr__(self):
def get_attribute(self, name: str) -> Optional[base.BaseValue]:
return self.members.get(name)

def instantiate(self) -> 'Instance':
def instantiate(self) -> 'FrozenInstance':
"""Creates an instance of this class."""
for setup_method_name in self.setup_methods:
setup_method = self.get_attribute(setup_method_name)
Expand All @@ -40,19 +56,31 @@ def instantiate(self) -> 'Instance':
if constructor:
raise NotImplementedError('Custom __new__')
else:
instance = Instance(self)
instance = MutableInstance(self)
for initializer_name in self.initializers:
initializer = self.get_attribute(initializer_name)
if isinstance(initializer, functions_lib.InterpreterFunction):
_ = initializer.bind_to(instance).analyze()
return instance
return instance.freeze()

def call(self, args: functions_lib.Args) -> ClassCallReturn:
constructor = self.get_attribute(self.constructor)
if constructor:
raise NotImplementedError('Custom __new__')
else:
instance = MutableInstance(self)
for initializer_name in self.initializers:
initializer = self.get_attribute(initializer_name)
if isinstance(initializer, functions_lib.InterpreterFunction):
_ = initializer.bind_to(instance).call(args)
return ClassCallReturn(instance)


class InterpreterClass(BaseClass):
"""Class defined in the current module."""

def __init__(
self, name: str, members: Mapping[str, base.BaseValue],
self, name: str, members: Dict[str, base.BaseValue],
functions: Sequence[functions_lib.InterpreterFunction],
classes: Sequence['InterpreterClass']):
super().__init__(name, members)
Expand All @@ -65,14 +93,49 @@ def __repr__(self):
return f'InterpreterClass({self.name})'


class Instance(base.BaseValue):
class MutableInstance(base.BaseValue):
"""Instance of a class."""

def __init__(self, cls: BaseClass):
self.cls = cls
self.members: Dict[str, base.BaseValue] = {}

def __repr__(self):
return f'Instance({self.cls.name})'
return f'MutableInstance({self.cls.name})'

def get_attribute(self, name: str) -> Optional[base.BaseValue]:
if name in self.members:
return self.members[name]
cls_attribute = self.cls.get_attribute(name)
if isinstance(cls_attribute, functions_lib.SimpleFunction):
return cls_attribute.bind_to(self)
return cls_attribute

def set_attribute(self, name: str, value: base.BaseValue) -> None:
if name in self.members:
raise NotImplementedError(f'Attribute already set: {name}')
self.members[name] = value

def freeze(self) -> 'FrozenInstance':
return FrozenInstance(self)


class FrozenInstance(base.BaseValue):
"""Frozen instance of a class.
This is used by BaseClass.instantiate() to create a snapshot of an instance
whose members map cannot be further modified.
"""

def __init__(self, instance: MutableInstance):
self._underlying = instance

@property
def cls(self):
return self._underlying.cls

def get_attribute(self, name: str) -> Optional[base.BaseValue]:
return self._underlying.get_attribute(name)


BUILD_CLASS = base.Singleton('BUILD_CLASS')
36 changes: 36 additions & 0 deletions pytype/rewrite/abstract/classes_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pytype.rewrite.abstract import base
from pytype.rewrite.abstract import classes
from pytype.rewrite.abstract import functions

import unittest

Expand All @@ -20,6 +21,41 @@ def test_instantiate(self):
instance = cls.instantiate()
self.assertEqual(instance.cls, cls)

def test_call(self):
cls = classes.BaseClass('X', {})
instance = cls.call(functions.Args()).get_return_value()
self.assertEqual(instance.cls, cls)


class MutableInstanceTest(unittest.TestCase):

def test_get_instance_attribute(self):
cls = classes.BaseClass('X', {})
instance = classes.MutableInstance(cls)
instance.members['x'] = base.PythonConstant(3)
self.assertEqual(instance.get_attribute('x'), base.PythonConstant(3))

def test_get_class_attribute(self):
cls = classes.BaseClass('X', {'x': base.PythonConstant(3)})
instance = classes.MutableInstance(cls)
self.assertEqual(instance.get_attribute('x'), base.PythonConstant(3))

def test_set_attribute(self):
cls = classes.BaseClass('X', {})
instance = classes.MutableInstance(cls)
instance.set_attribute('x', base.PythonConstant(3))
self.assertEqual(instance.members['x'], base.PythonConstant(3))


class FrozenInstanceTest(unittest.TestCase):

def test_get_attribute(self):
cls = classes.BaseClass('X', {})
mutable_instance = classes.MutableInstance(cls)
mutable_instance.set_attribute('x', base.PythonConstant(3))
instance = mutable_instance.freeze()
self.assertEqual(instance.get_attribute('x'), base.PythonConstant(3))


if __name__ == '__main__':
unittest.main()
3 changes: 3 additions & 0 deletions pytype/rewrite/abstract/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@ def analyze(self) -> Sequence[_HasReturnT]:


class _Frame(Protocol):
"""Protocol for a VM frame."""

final_locals: Mapping[str, base.BaseValue]

def make_child_frame(
self,
Expand Down
1 change: 1 addition & 0 deletions pytype/rewrite/abstract/functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class FakeFrame:

def __init__(self):
self.child_frames = []
self.final_locals = {}

def make_child_frame(self, func, initial_locals):
self.child_frames.append((func, initial_locals))
Expand Down
67 changes: 62 additions & 5 deletions pytype/rewrite/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def run(self) -> None:
# Set the current state to None so that the load_* and store_* methods
# cannot be used to modify finalized locals.
self._current_state = None
self.final_locals = self._final_locals_as_values()
self._finalize_locals()

def store_local(self, name: str, var: _AbstractVariable) -> None:
self._current_state.store_local(name, var)
Expand Down Expand Up @@ -253,7 +253,9 @@ def _call_function(
) -> None:
ret_values = []
for func in func_var.values:
if isinstance(func, abstract.InterpreterFunction):
if isinstance(func, (abstract.InterpreterFunction,
abstract.InterpreterClass,
abstract.BoundFunction)):
frame = func.call(args)
ret_values.append(frame.get_return_value())
elif func is abstract.BUILD_CLASS:
Expand All @@ -262,7 +264,7 @@ def _call_function(
frame = builder.call(abstract.Args())
cls = abstract.InterpreterClass(
name=abstract.get_atomic_constant(name, str),
members=frame.final_locals,
members=dict(frame.final_locals),
functions=frame.functions,
classes=frame.classes,
)
Expand All @@ -273,7 +275,7 @@ def _call_function(
self._stack.push(
variables.Variable(tuple(variables.Binding(v) for v in ret_values)))

def _final_locals_as_values(self) -> Mapping[str, abstract.BaseValue]:
def _finalize_locals(self) -> None:
final_values = {}
for name, var in self._final_locals.items():
values = var.values
Expand All @@ -283,7 +285,39 @@ def _final_locals_as_values(self) -> Mapping[str, abstract.BaseValue]:
final_values[name] = values[0]
else:
raise NotImplementedError('Empty variable not yet supported')
return immutabledict.immutabledict(final_values)
# We've stored SET_ATTR results as local values. Now actually perform the
# attribute setting.
# TODO(b/241479600): If we're deep in a stack of method calls, we should
# instead merge the attribute values into the parent frame so that any
# conditions on the bindings are preserved.
for name, value in final_values.items():
target_name, dot, attr_name = name.rpartition('.')
if not dot or target_name not in self._final_locals:
continue
for target in self._final_locals[target_name].values:
target.set_attribute(attr_name, value)
self.final_locals = immutabledict.immutabledict(final_values)

def _load_attr(
self, target_var: _AbstractVariable, attr_name: str) -> _AbstractVariable:
if target_var.name:
name = f'{target_var.name}.{attr_name}'
else:
name = None
try:
# Check if we've stored the attribute in the current frame.
return self.load_local(name)
except KeyError as e:
# We're loading an attribute without a locally stored value.
attr_bindings = []
for target in target_var.values:
attr = target.get_attribute(attr_name)
if not attr:
raise NotImplementedError('Attribute error') from e
# TODO(b/241479600): If there's a condition on the target binding, we
# should copy it.
attr_bindings.append(variables.Binding(attr))
return variables.Variable(tuple(attr_bindings), name)

def byte_RESUME(self, opcode):
del opcode # unused
Expand All @@ -307,6 +341,14 @@ def byte_STORE_GLOBAL(self, opcode):
def byte_STORE_DEREF(self, opcode):
self.store_deref(opcode.argval, self._stack.pop())

def byte_STORE_ATTR(self, opcode):
attr_name = opcode.argval
attr, target = self._stack.popn(2)
if not target.name:
raise NotImplementedError('Missing target name')
full_name = f'{target.name}.{attr_name}'
self.store_local(full_name, attr)

def byte_MAKE_FUNCTION(self, opcode):
if opcode.arg not in (0, pyc_marshal.Flags.MAKE_FUNCTION_HAS_FREE_VARS):
raise NotImplementedError('MAKE_FUNCTION not fully implemented')
Expand Down Expand Up @@ -358,6 +400,21 @@ def byte_LOAD_GLOBAL(self, opcode):
name = opcode.argval
self._stack.push(self.load_global(name))

def byte_LOAD_ATTR(self, opcode):
attr_name = opcode.argval
target_var = self._stack.pop()
self._stack.push(self._load_attr(target_var, attr_name))

def byte_LOAD_METHOD(self, opcode):
method_name = opcode.argval
instance_var = self._stack.pop()
# https://docs.python.org/3/library/dis.html#opcode-LOAD_METHOD says that
# this opcode should push two values onto the stack: either the unbound
# method and its `self` or NULL and the bound method. Since we always
# retrieve a bound method, we push the NULL
self._stack.push(abstract.NULL.to_variable())
self._stack.push(self._load_attr(instance_var, method_name))

def byte_PRECALL(self, opcode):
del opcode # unused

Expand Down
Loading

0 comments on commit 2437282

Please sign in to comment.