Skip to content

Commit

Permalink
Merge branch 'tusharsadhwani:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
yryuvraj committed Mar 11, 2024
2 parents 54182c8 + 12429a0 commit e275ecb
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 13 deletions.
118 changes: 107 additions & 11 deletions src/interpreted/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import sys
from collections import deque
from typing import Any
from typing import Any, Iterable
from unittest import mock

from interpreted import nodes
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(self, parent=None) -> None:
self.set("int", Int())
self.set("float", Float())
self.set("deque", DequeConstructor())
self.set("enumerate", Enumerate())

def get(self, name) -> Any:
return self.data.get(name, NOT_SET)
Expand Down Expand Up @@ -122,6 +124,20 @@ def call(self, _: Interpreter, args: list[Object]) -> Object:
raise InterpreterError(f"{type(item).__name__} has no len()")


class Enumerate(Function):
def as_string(self) -> str:
return "<function 'enumerate'>"

def arg_count(self) -> int:
return 1

def call(self, _: Interpreter, args: list[Object]) -> Object:
super().ensure_args(args)
# We don't have generator support yet :^)
pairs = [Tuple([Value(idx), val]) for idx, val in enumerate(args[0])]
return List(pairs)


class Int(Function):
def as_string(self) -> str:
return "<function 'int'>"
Expand Down Expand Up @@ -257,6 +273,24 @@ def call(self, _: Interpreter, args: list[Object]) -> None:
self.wrapper._data.append(item)


class Items(Function):
def __init__(self, wrapper: Dict) -> None:
super().__init__()
self.wrapper = wrapper

def as_string(self) -> str:
return f"<method 'items' of {self.wrapper.repr()}>"

def arg_count(self) -> int:
return 0

def call(self, _: Interpreter, args: list[Object]) -> Any:
super().ensure_args(args)
# We don't have generator support yet :^)
pairs = [Tuple(key_value_pair) for key_value_pair in self.wrapper._dict.items()]
return List(pairs)


class PopLeft(Function):
def __init__(self, deque: Deque) -> None:
super().__init__()
Expand Down Expand Up @@ -354,29 +388,35 @@ def call(self, _: Interpreter, args: list[Object]) -> Value:


class List(Object):
def __init__(self, elements) -> None:
def __init__(self, elements: Iterable[Object]) -> None:
super().__init__()
self._data = elements
self.methods["append"] = Append(self)

def as_string(self) -> str:
return "[" + ", ".join(item.repr() for item in self._data) + "]"

def __iter__(self) -> Iterable[Object]:
return iter(self._data)


class Tuple(Object):
def __init__(self, elements) -> None:
def __init__(self, elements: Iterable[Object]) -> None:
super().__init__()
self._data = elements

def as_string(self) -> str:
return "(" + ", ".join(item.repr() for item in self._data) + ")"

def __iter__(self) -> Iterable[Object]:
return iter(self._data)


class Dict(Object):
def __init__(self, keys: list[Object], values: list[Object]) -> None:
super().__init__()

self._dict = {key: value for key, value in zip(keys, values, strict=True)}
self._dict = {key: value for key, value in zip(keys, values)}
self.methods["items"] = Items(self)

def as_string(self) -> str:
return (
Expand All @@ -387,6 +427,9 @@ def as_string(self) -> str:
+ "}"
)

def __iter__(self) -> Iterable[Object]:
return iter(list(self._dict))


def is_truthy(obj: Object) -> bool:
if isinstance(obj, Value):
Expand Down Expand Up @@ -486,14 +529,16 @@ def visit_FunctionDef(self, node: FunctionDef) -> None:

self.scope.set(node.name, function)

def visit_Assign(self, node: Assign) -> None:
value = self.visit(node.value)
assert len(node.targets) == 1 # TODO
target = node.targets[0]

def assign(self, target: Node, value: Object) -> None:
if isinstance(target, Name):
self.scope.set(target.id, value)

elif isinstance(target, (nodes.List, nodes.Tuple)) and isinstance(
value, (List, Tuple, Deque, Dict)
):
for element, value in zip(target.elements, value):
self.assign(element, value)

elif isinstance(target, Subscript):
obj = self.visit(target.value)

Expand All @@ -517,7 +562,14 @@ def visit_Assign(self, node: Assign) -> None:
)

else:
raise NotImplementedError(target) # TODO
raise NotImplementedError(target, value) # TODO

def visit_Assign(self, node: Assign) -> None:
value = self.visit(node.value)
assert len(node.targets) == 1 # TODO
target = node.targets[0]

self.assign(target, value)

def visit_AugAssign(self, node: AugAssign) -> None:
increment = self.visit(node.value)
Expand All @@ -544,6 +596,29 @@ def visit_If(self, node: If) -> None:
for stmt in node.orelse:
self.visit(stmt)

def visit_For(self, node: nodes.For) -> None:
if isinstance(node.iterable, (nodes.List, nodes.Tuple)):
elements = [self.visit(element) for element in node.iterable.elements]
elif isinstance(node.iterable, nodes.Dict):
elements = [self.visit(element) for element in node.iterable.keys]
else:
elements = self.visit(node.iterable)
if not isinstance(elements, (List, Tuple, Deque, Dict)):
raise InterpreterError(
f"Object of type {type(elements).__name__} is not iterable"
)

for element in elements:
self.assign(node.target, element)

for stmt in node.body:
try:
self.visit(stmt)
except Break:
return
except Continue:
break

def visit_While(self, node: While) -> None:
while is_truthy(self.visit(node.condition)):
for stmt in node.body:
Expand Down Expand Up @@ -792,3 +867,24 @@ def interpret(source: str) -> None:
return

Interpreter().visit(module)


def main() -> None:
source = sys.stdin.read()
module = interpret(source)
if module is None:
return

if "--pretty" in sys.argv:
try:
import black
except ImportError:
print("Error: `black` needs to be installed for `--pretty` to work.")

print(black.format_str(repr(module), mode=black.Mode()))
else:
print(module)


if __name__ == "__main__":
main()
32 changes: 31 additions & 1 deletion src/interpreted/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def parse_multiline_statement(self) -> FunctionDef | For | If | While:
if keyword == "while":
return self.parse_while()

# TODO: for
if keyword == "for":
return self.parse_for()

raise NotImplementedError()

def parse_function_def(self) -> FunctionDef:
Expand Down Expand Up @@ -285,6 +287,34 @@ def parse_while(self) -> While:

return While(condition=condition, body=body, orelse=orelse)

def parse_for(self) -> For:
targets = []
targets.append(self.parse_primary())
while self.match_op(","):
# as soon as we see the first `in` keyword, we assume target to have ended
if self.peek().token_type == TokenType.NAME and self.peek().string == "in":
break

targets.append(self.parse_primary())

if len(targets) == 1:
target = targets[0]
else:
target = Tuple(targets)

self.expect_name("in")

expressions = self.parse_expressions()
if len(expressions) == 1:
iterable = expressions[0]
else:
iterable = Tuple(expressions)

self.expect_op(":")
body = self.parse_block()

return For(target=target, iterable=iterable, body=body, orelse=None)

def parse_block(self) -> list[Statement]:
self.expect(TokenType.NEWLINE)
self.expect(TokenType.INDENT)
Expand Down
75 changes: 75 additions & 0 deletions tests/interpreted_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,81 @@ def test_interpret(source, output) -> None:
assert process.stdout.decode() == dedent(output)


@pytest.mark.parametrize(
("source", "output"),
(
(
"""\
for e in [1,2]:
print(e)
""",
"1\n2\n",
),
(
"""\
lst = ['test','test123']
for e in lst:
print(e)
""",
"test\ntest123\n",
),
(
"""\
for x in 1, 2:
print(x)
""",
"1\n2\n",
),
(
"""\
dct = { "one": 1, "two": 2 }
for k in dct:
print(k, dct[k])
""",
"one 1\ntwo 2\n",
),
(
"""\
dct = { "one": 1, "two": 2 }
for k,v in dct.items():
print(k, v)
""",
"one 1\ntwo 2\n",
),
(
"""\
dct = { "one": 1, "two": 2 }
for k in dct.items():
print(k)
""",
"('one', 1)\n('two', 2)\n",
),
(
"""\
dct = { "one": 1, "two": 2 }
for idx, tup in enumerate(dct):
print(idx, tup)
""",
"0 one\n1 two\n",
),
),
)
def test_for(source, output) -> None:
"""Tests the interpreter CLI."""
with tempfile.NamedTemporaryFile("w+") as file:
file.write(dedent(source))
file.seek(0)

process = subprocess.run(
["interpreted", file.name],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)

assert process.stderr == b""
assert process.stdout.decode() == dedent(output)


def test_file_not_found() -> None:
"""Tests the file not found prompt."""
process = subprocess.run(
Expand Down
44 changes: 44 additions & 0 deletions tests/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
Compare,
Constant,
ExprStmt,
For,
Import,
ImportFrom,
Module,
Name,
Tuple,
While,
alias,
)
Expand Down Expand Up @@ -129,6 +131,48 @@
]
),
),
(
"""\
for a in b:
42
for i, j in x, t, u in y in a:
print(1)
""",
Module(
body=[
For(
target=Name(id="a"),
iterable=Name(id="b"),
body=[ExprStmt(value=Constant(value=42))],
orelse=None,
),
For(
target=Tuple(elements=[Name(id="i"), Name(id="j")]),
iterable=Tuple(
elements=[
Name(id="x"),
Name(id="t"),
Compare(
left=Compare(
left=Name(id="u"), op="in", right=Name(id="y")
),
op="in",
right=Name(id="a"),
),
]
),
body=[
ExprStmt(
value=Call(
function=Name(id="print"), args=[Constant(value=1)]
)
)
],
orelse=None,
),
]
),
),
),
)
def test_parser(source: str, tree: Module) -> None:
Expand Down
Loading

0 comments on commit e275ecb

Please sign in to comment.