diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 1817c8af4..7bf844bcd 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -699,7 +699,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client): _client: _Client = synchronizer._translate_in(client) # TODO(erikbern): ugly - with container_io_manager.heartbeats(), UserCodeEventLoop() as event_loop: + with container_io_manager.heartbeats(function_def.is_checkpointing_function), UserCodeEventLoop() as event_loop: # If this is a serialized function, fetch the definition from the server if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED: ser_cls, ser_fun = container_io_manager.get_serialized_function() @@ -779,6 +779,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client): if function_def.is_checkpointing_function: container_io_manager.memory_snapshot() + # Install hooks for interactive functions. if function_def.pty_info.pty_type != api_pb2.PTYInfo.PTY_TYPE_UNSPECIFIED: diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index fc8198445..ded25a955 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -73,8 +73,9 @@ class _ContainerIOManager: _input_concurrency: Optional[int] _semaphore: Optional[asyncio.Semaphore] _environment_name: str - _waiting_for_memory_snapshot: bool _heartbeat_loop: Optional[asyncio.Task] + _heartbeat_condition: asyncio.Condition + _waiting_for_memory_snapshot: bool _is_interactivity_enabled: bool _fetching_inputs: bool @@ -101,8 +102,9 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client): self._semaphore = None self._environment_name = container_args.environment_name - self._waiting_for_memory_snapshot = False self._heartbeat_loop = None + self._heartbeat_condition = asyncio.Condition() + self._waiting_for_memory_snapshot = False self._is_interactivity_enabled = False self._fetching_inputs = True @@ -146,22 +148,24 @@ async def _heartbeat_handle_cancellations(self) -> bool: # Return True if a cancellation event was received, in that case # we shouldn't wait too long for another heartbeat - # Don't send heartbeats for tasks waiting to be checkpointed. - # Calling gRPC methods open new connections which block the - # checkpointing process. - if self._waiting_for_memory_snapshot: - return False - request = api_pb2.ContainerHeartbeatRequest(supports_graceful_input_cancellation=True) if self.current_input_id is not None: request.current_input_id = self.current_input_id if self.current_input_started_at is not None: request.current_input_started_at = self.current_input_started_at - # TODO(erikbern): capture exceptions? - response = await retry_transient_errors( - self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT - ) + async with self._heartbeat_condition: + # Continuously wait until `waiting_for_memory_snapshot` is false. + # TODO(matt): Verify that a `while` is necessary over an `if`. Spurious + # wakeups could allow execution to continue despite `_waiting_for_memory_snapshot` + # being true. + while self._waiting_for_memory_snapshot: + await self._heartbeat_condition.wait() + + # TODO(erikbern): capture exceptions? + response = await retry_transient_errors( + self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT + ) if response.HasField("cancel_input_event"): # Pause processing of the current input by signaling self a SIGUSR1. @@ -196,10 +200,11 @@ async def _heartbeat_handle_cancellations(self) -> bool: return False @asynccontextmanager - async def heartbeats(self) -> AsyncGenerator[None, None]: + async def heartbeats(self, wait_for_mem_snap: bool) -> AsyncGenerator[None, None]: async with TaskContext() as tc: self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop()) t.set_name("heartbeat loop") + self._waiting_for_memory_snapshot = wait_for_mem_snap try: yield finally: @@ -617,22 +622,32 @@ async def memory_restore(self) -> None: ) self._client = await _Client.from_env() - self._waiting_for_memory_snapshot = False async def memory_snapshot(self) -> None: """Message server indicating that function is ready to be checkpointed.""" if self.checkpoint_id: logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)") - await self._client.stub.ContainerCheckpoint( - api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) - ) + # Pause heartbeats since they keep the client connection open which causes the snapshotter to crash + async with self._heartbeat_condition: + # Notify the heartbeat loop that the snapshot phase has begun in order to + # prevent it from sending heartbeat RPCs + self._waiting_for_memory_snapshot = True + self._heartbeat_condition.notify_all() + + await self._client.stub.ContainerCheckpoint( + api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) + ) + + await self._client._close(forget_credentials=True) - self._waiting_for_memory_snapshot = True - await self._client._close(forget_credentials=True) + logger.debug("Memory snapshot request sent. Connection closed.") + await self.memory_restore() - logger.debug("Memory snapshot request sent. Connection closed.") - await self.memory_restore() + # Turn heartbeats back on. This is safe since the snapshot RPC + # and the restore phase has finished. + self._waiting_for_memory_snapshot = False + self._heartbeat_condition.notify_all() async def volume_commit(self, volume_ids: List[str]) -> None: """ diff --git a/test/container_app_test.py b/test/container_app_test.py index 90f753e02..15855e9d1 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -1,4 +1,5 @@ # Copyright Modal Labs 2022 +import asyncio import json import os import pytest @@ -9,7 +10,8 @@ from google.protobuf.message import Message from modal import App, interact -from modal._container_io_manager import ContainerIOManager +from modal._container_io_manager import ContainerIOManager, _ContainerIOManager +from modal.client import _Client from modal.running_app import RunningApp from modal_proto import api_pb2 @@ -17,6 +19,21 @@ def my_f_1(x): pass +def temp_restore_path(tmpdir): + # Write out a restore file so that snapshot+restore will complete + restore_path = tmpdir.join("fake-restore-state.json") + restore_path.write_text( + json.dumps( + { + "task_id": "ta-i-am-restored", + "task_secret": "ts-i-am-restored", + "function_id": "fu-i-am-restored", + } + ), + encoding="utf-8", + ) + return restore_path + @pytest.mark.asyncio async def test_container_function_lazily_imported(container_client): @@ -45,18 +62,7 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer): # Get a reference to a Client instance in memory old_client = container_client io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client) - restore_path = tmpdir.join("fake-restore-state.json") - # Write out a restore file so that snapshot+restore will complete - restore_path.write_text( - json.dumps( - { - "task_id": "ta-i-am-restored", - "task_secret": "ts-i-am-restored", - "function_id": "fu-i-am-restored", - } - ), - encoding="utf-8", - ) + restore_path = temp_restore_path(tmpdir) with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): @@ -65,6 +71,32 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer): assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored") +@pytest.mark.asyncio +async def test_container_snapshot_restore_heartbeats(tmpdir, servicer): + client = _Client(servicer.container_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret")) + async with client as async_client: + io_manager = _ContainerIOManager(api_pb2.ContainerArguments(), async_client) + restore_path = temp_restore_path(tmpdir) + + # Ensure that heartbeats only run after the snapshot + heartbeat_interval_secs = 0.01 + async with io_manager.heartbeats(True): + with mock.patch.dict( + os.environ, + {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, + ): + with mock.patch("modal.runner.HEARTBEAT_INTERVAL", heartbeat_interval_secs): + await asyncio.sleep(heartbeat_interval_secs*2) + assert not list( + filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests) + ) + await io_manager.memory_snapshot() + await asyncio.sleep(heartbeat_interval_secs*2) + assert list( + filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests) + ) + + @pytest.mark.asyncio async def test_container_debug_snapshot(container_client, tmpdir, servicer): # Get an IO manager, where restore takes place @@ -121,7 +153,6 @@ def weird_torch_module(): @pytest.mark.asyncio async def test_container_snapshot_patching(fake_torch_module, container_client, tmpdir, servicer): io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client) - restore_path = tmpdir.join("fake-restore-state.json") # bring fake torch into scope and call the utility fn import torch @@ -129,16 +160,7 @@ async def test_container_snapshot_patching(fake_torch_module, container_client, assert torch.cuda.device_count() == 0 # Write out a restore file so that snapshot+restore will complete - restore_path.write_text( - json.dumps( - { - "task_id": "ta-i-am-restored", - "task_secret": "ts-i-am-restored", - "function_id": "fu-i-am-restored", - } - ), - encoding="utf-8", - ) + restore_path = temp_restore_path(tmpdir) with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): @@ -151,24 +173,13 @@ async def test_container_snapshot_patching(fake_torch_module, container_client, @pytest.mark.asyncio async def test_container_snapshot_patching_err(weird_torch_module, container_client, tmpdir, servicer): io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client) - restore_path = tmpdir.join("fake-restore-state.json") + restore_path = temp_restore_path(tmpdir) # bring weird torch into scope and call the utility fn import torch as trch assert trch.IM_WEIRD == 42 - # Write out a restore file so that snapshot+restore will complete - restore_path.write_text( - json.dumps( - { - "task_id": "ta-i-am-restored", - "task_secret": "ts-i-am-restored", - "function_id": "fu-i-am-restored", - } - ), - encoding="utf-8", - ) with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): diff --git a/test/container_test.py b/test/container_test.py index f69bd749b..a6b2702ec 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -790,7 +790,7 @@ def test_checkpointing_cls_function(servicer): ret = _run_container( servicer, "test.supports.functions", - "CheckpointingCls.*", + "SnapshottingCls.*", inputs=_get_inputs((("D",), {}), method_name="f"), is_checkpointing_function=True, is_class=True, @@ -819,6 +819,9 @@ def test_container_heartbeats(servicer): _run_container(servicer, "test.supports.functions", "square") assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in servicer.requests) + _run_container(servicer, "test.supports.functions", "snapshotting_square") + assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in servicer.requests) + @skip_github_non_linux def test_cli(servicer): diff --git a/test/supports/functions.py b/test/supports/functions.py index d2a9a84df..4fcc9618d 100644 --- a/test/supports/functions.py +++ b/test/supports/functions.py @@ -393,7 +393,7 @@ def f(self, x): @app.cls(enable_memory_snapshot=True) -class CheckpointingCls: +class SnapshottingCls: def __init__(self): self._vals = [] @@ -414,6 +414,11 @@ def f(self, x): return "".join(self._vals) + x +@app.function(enable_memory_snapshot=True) +def snapshotting_square(x): + return x * x + + @app.cls() class EventLoopCls: @enter()