Skip to content

Commit

Permalink
Distinguish between local and remote exceptions when stopping ephemer…
Browse files Browse the repository at this point in the history
…al app (#2071)

* Distinguish between local and remote exceptions when stopping ephemeral app

* Add comments so we know about the implicit dependency here
  • Loading branch information
mwaskom committed Aug 6, 2024
1 parent 109f528 commit cb6b242
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
14 changes: 13 additions & 1 deletion modal/_traceback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright Modal Labs 2022
import functools
import re
import traceback
import warnings
from types import TracebackType
Expand Down Expand Up @@ -37,6 +38,8 @@ def extract_traceback(exc: BaseException, task_id: str) -> Tuple[TBDictType, Lin
# container. This means we've reached the end of the local traceback.
if file.startswith("<"):
break
# We rely on this specific filename format when inferring where the exception was raised
# in various other exception-related code
cur.tb_frame.f_code.co_filename = f"<{task_id}>:{file}"
cur = cur.tb_next

Expand Down Expand Up @@ -67,7 +70,7 @@ def append_modal_tb(exc: BaseException, tb_dict: TBDictType, line_cache: LineCac
setattr(exc, "__line_cache__", line_cache)


def reduce_traceback_to_user_code(tb: TracebackType, user_source: str) -> TracebackType:
def reduce_traceback_to_user_code(tb: Optional[TracebackType], user_source: str) -> TracebackType:
"""Return a traceback that does not contain modal entrypoint or synchronicity frames."""
# Step forward all the way through the traceback and drop any synchronicity frames
tb_root = tb
Expand All @@ -94,6 +97,15 @@ def reduce_traceback_to_user_code(tb: TracebackType, user_source: str) -> Traceb
return tb


def traceback_contains_remote_call(tb: Optional[TracebackType]) -> bool:
"""Inspect the traceback stack to determine whether an error was raised locally or remotely."""
while tb is not None:
if re.match(r"^<ta-[0-9A-Z]{26}>:", tb.tb_frame.f_code.co_filename):
return True
tb = tb.tb_next
return False


@group()
def _render_stack(self, stack: Stack) -> RenderResult:
"""Patched variant of rich.Traceback._render_stack that uses the line from the modal StackSummary,
Expand Down
6 changes: 5 additions & 1 deletion modal/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ._pty import get_pty_info
from ._resolver import Resolver
from ._sandbox_shell import connect_to_sandbox
from ._traceback import traceback_contains_remote_call
from ._utils.async_utils import TaskContext, synchronize_api
from ._utils.grpc_utils import retry_transient_errors
from ._utils.name_utils import check_object_name, is_valid_tag
Expand Down Expand Up @@ -328,7 +329,10 @@ async def _run_app(
if isinstance(exc_info, KeyboardInterrupt):
reason = api_pb2.APP_DISCONNECT_REASON_KEYBOARD_INTERRUPT
elif exc_info is not None:
reason = api_pb2.APP_DISCONNECT_REASON_LOCAL_EXCEPTION
if traceback_contains_remote_call(exc_info.__traceback__):
reason = api_pb2.APP_DISCONNECT_REASON_REMOTE_EXCEPTION
else:
reason = api_pb2.APP_DISCONNECT_REASON_LOCAL_EXCEPTION
else:
reason = api_pb2.APP_DISCONNECT_REASON_ENTRYPOINT_COMPLETED

Expand Down
28 changes: 27 additions & 1 deletion test/traceback_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from traceback import extract_tb
from typing import Dict, List, Tuple

from modal._traceback import append_modal_tb, extract_traceback, reduce_traceback_to_user_code
from modal._traceback import (
append_modal_tb,
extract_traceback,
reduce_traceback_to_user_code,
traceback_contains_remote_call,
)
from modal._vendor import tblib

from .supports.raise_error import raise_error
Expand Down Expand Up @@ -133,3 +138,24 @@ def test_reduce_traceback_to_user_code(user_mode):
assert f.f_code.co_name == "execute"

assert tb_out.tb_next.tb_next is None


def test_traceback_contains_remote_call():
stack = [
("/home/foobar/code/script.py", "f"),
("/usr/local/venv/modal.py", "local"),
]

tb = tblib.Traceback.from_dict(tb_dict_from_stack_dicts(make_tb_stack(stack)))
assert not traceback_contains_remote_call(tb)

task_id = "ta-0123456789ABCDEFGHILJKMNOP"
stack.extend(
[
(f"<{task_id}>:/usr/local/lib/python3.11/importlib/__init__.py", ""),
("/root/script.py", ""),
]
)

tb = tblib.Traceback.from_dict(tb_dict_from_stack_dicts(make_tb_stack(stack)))
assert traceback_contains_remote_call(tb)

0 comments on commit cb6b242

Please sign in to comment.