Skip to content

Commit

Permalink
Fix edge case of import *
Browse files Browse the repository at this point in the history
  • Loading branch information
tusharsadhwani committed Oct 11, 2023
1 parent 7fcc24c commit f88b48e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
17 changes: 11 additions & 6 deletions src/interpreted/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@
class Scope:
def __init__(self, parent=None) -> None:
self.data = {}
self.parent = parent
self.set("print", Print())
self.set("len", Len())
self.set("int", Int())
self.set("float", Float())
self.set("deque", DequeConstructor())
self.parent = parent

def get(self, name) -> Any:
return self.data.get(name, NOT_SET)
Expand Down Expand Up @@ -172,7 +172,12 @@ def __init__(self, value: Object) -> None:


class UserFunction(Function):
def __init__(self, definition: FunctionDef, parent_scope: Scope, current_globals: Scope) -> None:
def __init__(
self,
definition: FunctionDef,
parent_scope: Scope,
current_globals: Scope,
) -> None:
self.definition = definition
self.parent_scope = parent_scope
self.current_globals = current_globals
Expand Down Expand Up @@ -427,7 +432,7 @@ def visit_Import(self, node: Import) -> None:
self.scope = parent_scope
self.globals = parent_globals

module_obj = Module(members=vars(module_scope))
module_obj = Module(members=module_scope.data)

self.scope.set(name, module_obj)

Expand All @@ -454,7 +459,7 @@ def visit_ImportFrom(self, node: ImportFrom) -> None:
for alias in node.names:
name = alias.name
if name == "*":
for member, value in vars(module_scope).items():
for member, value in module_scope.data.items():
self.scope.set(member, value)
return

Expand Down Expand Up @@ -745,11 +750,11 @@ def visit_Name(self, node: Name) -> Value:
current_scope = current_scope.parent
else:
return value

value = self.globals.get(name)
if value is NOT_SET:
raise InterpreterError(f"{name!r} is not defined")

return value

def visit_List(self, node: nodes.List) -> List:
Expand Down
28 changes: 16 additions & 12 deletions tests/interpreted_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,24 @@ def foo(x):
),
(
"""\
x = 5
def bar():
x = 10
def baz():
def foo():
print(x)
return foo
return baz
foo = bar()()
foo()
x = 5
def bar():
x = 10
def baz():
def foo():
print(x)
return foo
return baz
foo = bar()()
foo()
""",
"10\n",
)
),
),
)
def test_interpret(source, output) -> None:
Expand Down

0 comments on commit f88b48e

Please sign in to comment.