From cb6b2427c91544237275df7aea0d87a635f4643c Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 6 Aug 2024 10:04:12 -0400 Subject: [PATCH] Distinguish between local and remote exceptions when stopping ephemeral app (#2071) * Distinguish between local and remote exceptions when stopping ephemeral app * Add comments so we know about the implicit dependency here --- modal/_traceback.py | 14 +++++++++++++- modal/runner.py | 6 +++++- test/traceback_test.py | 28 +++++++++++++++++++++++++++- 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/modal/_traceback.py b/modal/_traceback.py index c9ee7cbaf..4a7e2d216 100644 --- a/modal/_traceback.py +++ b/modal/_traceback.py @@ -1,5 +1,6 @@ # Copyright Modal Labs 2022 import functools +import re import traceback import warnings from types import TracebackType @@ -37,6 +38,8 @@ def extract_traceback(exc: BaseException, task_id: str) -> Tuple[TBDictType, Lin # container. This means we've reached the end of the local traceback. if file.startswith("<"): break + # We rely on this specific filename format when inferring where the exception was raised + # in various other exception-related code cur.tb_frame.f_code.co_filename = f"<{task_id}>:{file}" cur = cur.tb_next @@ -67,7 +70,7 @@ def append_modal_tb(exc: BaseException, tb_dict: TBDictType, line_cache: LineCac setattr(exc, "__line_cache__", line_cache) -def reduce_traceback_to_user_code(tb: TracebackType, user_source: str) -> TracebackType: +def reduce_traceback_to_user_code(tb: Optional[TracebackType], user_source: str) -> TracebackType: """Return a traceback that does not contain modal entrypoint or synchronicity frames.""" # Step forward all the way through the traceback and drop any synchronicity frames tb_root = tb @@ -94,6 +97,15 @@ def reduce_traceback_to_user_code(tb: TracebackType, user_source: str) -> Traceb return tb +def traceback_contains_remote_call(tb: Optional[TracebackType]) -> bool: + """Inspect the traceback stack to determine whether an error 
was raised locally or remotely.""" + while tb is not None: + if re.match(r"^<ta-[0-9A-Z]{26}>:", tb.tb_frame.f_code.co_filename): + return True + tb = tb.tb_next + return False + + @group() def _render_stack(self, stack: Stack) -> RenderResult: """Patched variant of rich.Traceback._render_stack that uses the line from the modal StackSummary, diff --git a/modal/runner.py b/modal/runner.py index 0663a301f..bfd2475db 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -15,6 +15,7 @@ from ._pty import get_pty_info from ._resolver import Resolver from ._sandbox_shell import connect_to_sandbox +from ._traceback import traceback_contains_remote_call from ._utils.async_utils import TaskContext, synchronize_api from ._utils.grpc_utils import retry_transient_errors from ._utils.name_utils import check_object_name, is_valid_tag @@ -328,7 +329,10 @@ async def _run_app( if isinstance(exc_info, KeyboardInterrupt): reason = api_pb2.APP_DISCONNECT_REASON_KEYBOARD_INTERRUPT elif exc_info is not None: - reason = api_pb2.APP_DISCONNECT_REASON_LOCAL_EXCEPTION + if traceback_contains_remote_call(exc_info.__traceback__): + reason = api_pb2.APP_DISCONNECT_REASON_REMOTE_EXCEPTION + else: + reason = api_pb2.APP_DISCONNECT_REASON_LOCAL_EXCEPTION else: reason = api_pb2.APP_DISCONNECT_REASON_ENTRYPOINT_COMPLETED diff --git a/test/traceback_test.py b/test/traceback_test.py index 13edb41ba..9df0aba04 100644 --- a/test/traceback_test.py +++ b/test/traceback_test.py @@ -4,7 +4,12 @@ from traceback import extract_tb from typing import Dict, List, Tuple -from modal._traceback import append_modal_tb, extract_traceback, reduce_traceback_to_user_code +from modal._traceback import ( + append_modal_tb, + extract_traceback, + reduce_traceback_to_user_code, + traceback_contains_remote_call, +) from modal._vendor import tblib from .supports.raise_error import raise_error @@ -133,3 +138,24 @@ def test_reduce_traceback_to_user_code(user_mode): assert f.f_code.co_name == "execute" assert tb_out.tb_next.tb_next is None + +
+def test_traceback_contains_remote_call(): + stack = [ + ("/home/foobar/code/script.py", "f"), + ("/usr/local/venv/modal.py", "local"), + ] + + tb = tblib.Traceback.from_dict(tb_dict_from_stack_dicts(make_tb_stack(stack))) + assert not traceback_contains_remote_call(tb) + + task_id = "ta-0123456789ABCDEFGHILJKMNOP" + stack.extend( + [ + (f"<{task_id}>:/usr/local/lib/python3.11/importlib/__init__.py", ""), + ("/root/script.py", ""), + ] + ) + + tb = tblib.Traceback.from_dict(tb_dict_from_stack_dicts(make_tb_stack(stack))) + assert traceback_contains_remote_call(tb)