diff --git a/src/interpreted/interpreter.py b/src/interpreted/interpreter.py index 77ff973..0b6d688 100644 --- a/src/interpreted/interpreter.py +++ b/src/interpreted/interpreter.py @@ -34,7 +34,9 @@ class Scope: - def __init__(self): + def __init__(self, parent=None) -> None: + self.data = {} + self.parent = parent self.set("print", Print()) self.set("len", Len()) self.set("int", Int()) @@ -42,10 +44,10 @@ def __init__(self): self.set("deque", DequeConstructor()) def get(self, name) -> Any: - return getattr(self, name, NOT_SET) + return self.data.get(name, NOT_SET) def set(self, name, value) -> None: - setattr(self, name, value) + self.data[name] = value class InterpreterError(Exception): @@ -170,8 +172,14 @@ def __init__(self, value: Object) -> None: class UserFunction(Function): - def __init__(self, definition: FunctionDef, current_globals: Scope) -> None: + def __init__( + self, + definition: FunctionDef, + parent_scope: Scope, + current_globals: Scope, + ) -> None: self.definition = definition + self.parent_scope = parent_scope self.current_globals = current_globals def as_string(self) -> str: @@ -183,10 +191,10 @@ def arg_count(self) -> int: def call(self, interpreter: Interpreter, args: list[Object]) -> Object: super().ensure_args(args) - parent_scope = interpreter.scope + current_scope = interpreter.scope parent_globals = interpreter.globals - function_scope = Scope() + function_scope = Scope(parent=self.parent_scope) interpreter.globals = self.current_globals interpreter.scope = function_scope @@ -201,7 +209,7 @@ def call(self, interpreter: Interpreter, args: list[Object]) -> Object: return ret.value finally: - interpreter.scope = parent_scope + interpreter.scope = current_scope interpreter.globals = parent_globals return Value(None) @@ -424,7 +432,7 @@ def visit_Import(self, node: Import) -> None: self.scope = parent_scope self.globals = parent_globals - module_obj = Module(members=vars(module_scope)) + module_obj = Module(members=module_scope.data) self.scope.set(name, module_obj) @@ -451,7 +459,7 @@ def visit_ImportFrom(self, node: ImportFrom) -> None: for alias in node.names: name = alias.name if name == "*": - for member, value in vars(module_scope).items(): + for member, value in module_scope.data.items(): self.scope.set(member, value) return @@ -462,7 +470,8 @@ def visit_ImportFrom(self, node: ImportFrom) -> None: self.scope.set(name, member) def visit_FunctionDef(self, node: FunctionDef) -> None: - function = UserFunction(node, self.globals) + parent_scope = self.scope + function = UserFunction(node, parent_scope, self.globals) self.scope.set(node.name, function) def visit_Assign(self, node: Assign) -> None: @@ -734,12 +743,17 @@ def visit_Attribute(self, node: Attribute) -> Object: def visit_Name(self, node: Name) -> Value: name = node.id - value = self.scope.get(name) + current_scope = self.scope + while current_scope is not None: + value = current_scope.get(name) + if value is NOT_SET: + current_scope = current_scope.parent + else: + return value + value = self.globals.get(name) if value is NOT_SET: - value = self.globals.get(name) - if value is NOT_SET: - raise InterpreterError(f"{name!r} is not defined") + raise InterpreterError(f"{name!r} is not defined") return value diff --git a/tests/interpreted_test.py b/tests/interpreted_test.py index d9197b9..275964d 100644 --- a/tests/interpreted_test.py +++ b/tests/interpreted_test.py @@ -75,6 +75,26 @@ def foo(x): """, "a\nbc\nab\nabc\nb\n", ), + ( + """\ + x = 5 + + def bar(): + x = 10 + + def baz(): + def foo(): + print(x) + + return foo + + return baz + + foo = bar()() + foo() + """, + "10\n", + ), ), ) def test_interpret(source, output) -> None: