From f3c04c600b65cdc607258edefdd69aa296bae03c Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Fri, 12 Jul 2024 13:27:16 +0000 Subject: [PATCH 01/22] Stop heartbeats before snapshotting --- modal/_container_io_manager.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 9cc97571a..eaa16416e 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -609,6 +609,9 @@ async def memory_snapshot(self) -> None: if self.checkpoint_id: logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)") + # Heartbeats can leave the modal.sock file open, causing gVisor to crash + self.stop_heartbeat() + await self._client.stub.ContainerCheckpoint( api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) ) From 4cceb72125a3a1c758b158a774efd76b2d13f3ce Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Fri, 12 Jul 2024 16:43:44 +0000 Subject: [PATCH 02/22] Added asyncio.Event for pausing heartbeats --- modal/_container_io_manager.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index eaa16416e..f02060c7b 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -74,6 +74,7 @@ class _ContainerIOManager: _environment_name: str _waiting_for_memory_snapshot: bool _heartbeat_loop: Optional[asyncio.Task] + _pause_heartbeats: Optional[asyncio.Event] _is_interactivity_enabled: bool _fetching_inputs: bool @@ -102,6 +103,7 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client): self._environment_name = container_args.environment_name self._waiting_for_memory_snapshot = False self._heartbeat_loop = None + self._pause_heartbeats = None self._is_interactivity_enabled = False self._fetching_inputs = True @@ -157,6 +159,9 @@ async def _heartbeat_handle_cancellations(self) -> bool: if self.current_input_started_at is not None: 
request.current_input_started_at = self.current_input_started_at + # Wait until memory snapshotting finishes + await self._pause_heartbeats.wait() + # TODO(erikbern): capture exceptions? response = await retry_transient_errors( self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT @@ -198,6 +203,7 @@ async def _heartbeat_handle_cancellations(self) -> bool: async def heartbeats(self) -> AsyncGenerator[None, None]: async with TaskContext() as tc: self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop()) + self._pause_heartbeats = asyncio.Event() t.set_name("heartbeat loop") try: yield @@ -572,6 +578,7 @@ async def memory_restore(self) -> None: logger.debug(f"Waiting for restore (elapsed={time.perf_counter() - start:.3f}s)") await asyncio.sleep(0.01) continue + self._pause_heartbeats.set() # Resume heartbeats logger.debug("Container: restored") @@ -609,9 +616,8 @@ async def memory_snapshot(self) -> None: if self.checkpoint_id: logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)") - # Heartbeats can leave the modal.sock file open, causing gVisor to crash - self.stop_heartbeat() - + # Pause heartbeats, since they keep the client connection open, causing the snapshotter to crash. 
+ self._pause_heartbeats.clear() await self._client.stub.ContainerCheckpoint( api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) ) From 626af1c587865612d652a41a04fb1e2d0d9f4e5f Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Fri, 12 Jul 2024 18:24:02 +0000 Subject: [PATCH 03/22] Use asyncio.Condition to ensure mutual exclusion --- modal/_container_io_manager.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 2048dc225..71f4473ff 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -75,7 +75,7 @@ class _ContainerIOManager: _environment_name: str _waiting_for_memory_snapshot: bool _heartbeat_loop: Optional[asyncio.Task] - _pause_heartbeats: Optional[asyncio.Event] + _pause_heartbeats: Optional[asyncio.Condition] _is_interactivity_enabled: bool _fetching_inputs: bool @@ -160,13 +160,14 @@ async def _heartbeat_handle_cancellations(self) -> bool: if self.current_input_started_at is not None: request.current_input_started_at = self.current_input_started_at - # Wait until memory snapshotting finishes - await self._pause_heartbeats.wait() + async with self._pause_heartbeats: + await self._pause_heartbeats.wait() - # TODO(erikbern): capture exceptions? - response = await retry_transient_errors( - self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT - ) + # TODO(erikbern): capture exceptions? + response = await retry_transient_errors( + self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT + ) + print("SENT HEARTBEAT") if response.HasField("cancel_input_event"): # Pause processing of the current input by signaling self a SIGUSR1. 
@@ -204,7 +205,7 @@ async def _heartbeat_handle_cancellations(self) -> bool: async def heartbeats(self) -> AsyncGenerator[None, None]: async with TaskContext() as tc: self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop()) - self._pause_heartbeats = asyncio.Event() + self._pause_heartbeats = asyncio.Condition() t.set_name("heartbeat loop") try: yield @@ -579,7 +580,6 @@ async def memory_restore(self) -> None: logger.debug(f"Waiting for restore (elapsed={time.perf_counter() - start:.3f}s)") await asyncio.sleep(0.01) continue - self._pause_heartbeats.set() # Resume heartbeats logger.debug("Container: restored") @@ -632,10 +632,10 @@ async def memory_snapshot(self) -> None: logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)") # Pause heartbeats, since they keep the client connection open, causing the snapshotter to crash. - self._pause_heartbeats.clear() - await self._client.stub.ContainerCheckpoint( - api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) - ) + async with self._pause_heartbeats: + await self._client.stub.ContainerCheckpoint( + api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) + ) self._waiting_for_memory_snapshot = True await self._client._close(forget_credentials=True) From 72286520563faf4e078c45457589b0b71edb7550 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Fri, 12 Jul 2024 19:05:02 +0000 Subject: [PATCH 04/22] Fixed unit tests --- modal/_container_io_manager.py | 4 +--- test/container_app_test.py | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 71f4473ff..08096f2ee 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -162,12 +162,10 @@ async def _heartbeat_handle_cancellations(self) -> bool: async with self._pause_heartbeats: await self._pause_heartbeats.wait() - # TODO(erikbern): capture exceptions? 
response = await retry_transient_errors( self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT ) - print("SENT HEARTBEAT") if response.HasField("cancel_input_event"): # Pause processing of the current input by signaling self a SIGUSR1. @@ -631,7 +629,7 @@ async def memory_snapshot(self) -> None: if self.checkpoint_id: logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)") - # Pause heartbeats, since they keep the client connection open, causing the snapshotter to crash. + # Pause heartbeats since they keep the client connection open which causes the snapshotter to crash async with self._pause_heartbeats: await self._client.stub.ContainerCheckpoint( api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) diff --git a/test/container_app_test.py b/test/container_app_test.py index 90f753e02..345686221 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -58,11 +58,12 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer): encoding="utf-8", ) with mock.patch.dict( - os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} + os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, ): - io_manager.memory_snapshot() - # In-memory Client instance should have update credentials, not old credentials - assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored") + with io_manager.heartbeats(): + io_manager.memory_snapshot() + # In-memory Client instance should have update credentials, not old credentials + assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored") @pytest.mark.asyncio @@ -82,8 +83,9 @@ async def test_container_debug_snapshot(container_client, tmpdir, servicer): with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): - io_manager.memory_snapshot() - 
test_breakpoint.assert_called_once() + with io_manager.heartbeats(): + io_manager.memory_snapshot() + test_breakpoint.assert_called_once() @pytest.fixture(scope="function") @@ -142,10 +144,11 @@ async def test_container_snapshot_patching(fake_torch_module, container_client, with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): - io_manager.memory_snapshot() - import torch + with io_manager.heartbeats(): + io_manager.memory_snapshot() + import torch - assert torch.cuda.device_count() == 2 + assert torch.cuda.device_count() == 2 @pytest.mark.asyncio @@ -172,7 +175,8 @@ async def test_container_snapshot_patching_err(weird_torch_module, container_cli with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): - io_manager.memory_snapshot() # should not crash + with io_manager.heartbeats(): + io_manager.memory_snapshot() # should not crash def test_interact(container_client, servicer): From 86b8fd5f41b674b56b7c6b6dda6ec7d8e3903f5d Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 16:24:18 +0000 Subject: [PATCH 05/22] Fixed bug, but now its slow --- modal/_container_io_manager.py | 30 ++++++++++++++++++++++++------ test/container_app_test.py | 22 +++++++++------------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 08096f2ee..c35ca5bc3 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -76,6 +76,7 @@ class _ContainerIOManager: _waiting_for_memory_snapshot: bool _heartbeat_loop: Optional[asyncio.Task] _pause_heartbeats: Optional[asyncio.Condition] + _snapshot_running: bool _is_interactivity_enabled: bool _fetching_inputs: bool @@ -104,7 +105,8 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client): self._environment_name = container_args.environment_name 
self._waiting_for_memory_snapshot = False self._heartbeat_loop = None - self._pause_heartbeats = None + self._pause_heartbeats = asyncio.Condition() + self._snapshot_running = False self._is_interactivity_enabled = False self._fetching_inputs = True @@ -161,11 +163,15 @@ async def _heartbeat_handle_cancellations(self) -> bool: request.current_input_started_at = self.current_input_started_at async with self._pause_heartbeats: - await self._pause_heartbeats.wait() + print("heartbeat acquired lock") + while self._snapshot_running: + await self._pause_heartbeats.wait() + # TODO(erikbern): capture exceptions? response = await retry_transient_errors( self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT ) + print("heartbeat released lock") if response.HasField("cancel_input_event"): # Pause processing of the current input by signaling self a SIGUSR1. @@ -203,7 +209,7 @@ async def _heartbeat_handle_cancellations(self) -> bool: async def heartbeats(self) -> AsyncGenerator[None, None]: async with TaskContext() as tc: self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop()) - self._pause_heartbeats = asyncio.Condition() + self._snapshot_running = False t.set_name("heartbeat loop") try: yield @@ -579,6 +585,13 @@ async def memory_restore(self) -> None: await asyncio.sleep(0.01) continue + # Turn heartbeats back on + async with self._pause_heartbeats: + print("restore acquired lock") + self._snapshot_running = False + self._pause_heartbeats.notify_all() + print("restore released lock") + logger.debug("Container: restored") # Look for state file and create new client with updated credentials. 
@@ -631,9 +644,14 @@ async def memory_snapshot(self) -> None: # Pause heartbeats since they keep the client connection open which causes the snapshotter to crash async with self._pause_heartbeats: - await self._client.stub.ContainerCheckpoint( - api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) - ) + print("snapshot acquired lock") + self._snapshot_running = True + self._pause_heartbeats.notify_all() + + await self._client.stub.ContainerCheckpoint( + api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) + ) + print("snapshot sent request") self._waiting_for_memory_snapshot = True await self._client._close(forget_credentials=True) diff --git a/test/container_app_test.py b/test/container_app_test.py index 345686221..b4eff53f6 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -60,10 +60,9 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer): with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, ): - with io_manager.heartbeats(): - io_manager.memory_snapshot() - # In-memory Client instance should have update credentials, not old credentials - assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored") + io_manager.memory_snapshot() + # In-memory Client instance should have update credentials, not old credentials + assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored") @pytest.mark.asyncio @@ -83,9 +82,8 @@ async def test_container_debug_snapshot(container_client, tmpdir, servicer): with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): - with io_manager.heartbeats(): - io_manager.memory_snapshot() - test_breakpoint.assert_called_once() + io_manager.memory_snapshot() + test_breakpoint.assert_called_once() @pytest.fixture(scope="function") @@ -144,11 +142,10 @@ async def 
test_container_snapshot_patching(fake_torch_module, container_client, with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): - with io_manager.heartbeats(): - io_manager.memory_snapshot() - import torch + io_manager.memory_snapshot() + import torch - assert torch.cuda.device_count() == 2 + assert torch.cuda.device_count() == 2 @pytest.mark.asyncio @@ -175,8 +172,7 @@ async def test_container_snapshot_patching_err(weird_torch_module, container_cli with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): - with io_manager.heartbeats(): - io_manager.memory_snapshot() # should not crash + io_manager.memory_snapshot() # should not crash def test_interact(container_client, servicer): From 4c6ed72effee13e628d1f7a9209f1d6adb20225b Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 18:10:32 +0000 Subject: [PATCH 06/22] Fixed bottleneck --- modal/_container_entrypoint.py | 3 ++- modal/_container_io_manager.py | 11 ++--------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/modal/_container_entrypoint.py b/modal/_container_entrypoint.py index 1817c8af4..7bf844bcd 100644 --- a/modal/_container_entrypoint.py +++ b/modal/_container_entrypoint.py @@ -699,7 +699,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client): _client: _Client = synchronizer._translate_in(client) # TODO(erikbern): ugly - with container_io_manager.heartbeats(), UserCodeEventLoop() as event_loop: + with container_io_manager.heartbeats(function_def.is_checkpointing_function), UserCodeEventLoop() as event_loop: # If this is a serialized function, fetch the definition from the server if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED: ser_cls, ser_fun = container_io_manager.get_serialized_function() @@ -779,6 +779,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client): if 
function_def.is_checkpointing_function: container_io_manager.memory_snapshot() + # Install hooks for interactive functions. if function_def.pty_info.pty_type != api_pb2.PTYInfo.PTY_TYPE_UNSPECIFIED: diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index c35ca5bc3..6df817a17 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -106,7 +106,6 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client): self._waiting_for_memory_snapshot = False self._heartbeat_loop = None self._pause_heartbeats = asyncio.Condition() - self._snapshot_running = False self._is_interactivity_enabled = False self._fetching_inputs = True @@ -163,7 +162,6 @@ async def _heartbeat_handle_cancellations(self) -> bool: request.current_input_started_at = self.current_input_started_at async with self._pause_heartbeats: - print("heartbeat acquired lock") while self._snapshot_running: await self._pause_heartbeats.wait() @@ -171,7 +169,6 @@ async def _heartbeat_handle_cancellations(self) -> bool: response = await retry_transient_errors( self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT ) - print("heartbeat released lock") if response.HasField("cancel_input_event"): # Pause processing of the current input by signaling self a SIGUSR1. 
@@ -206,11 +203,11 @@ async def _heartbeat_handle_cancellations(self) -> bool: return False @asynccontextmanager - async def heartbeats(self) -> AsyncGenerator[None, None]: + async def heartbeats(self, enable: bool) -> AsyncGenerator[None, None]: async with TaskContext() as tc: self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop()) - self._snapshot_running = False t.set_name("heartbeat loop") + self._snapshot_running = enable try: yield finally: @@ -587,10 +584,8 @@ async def memory_restore(self) -> None: # Turn heartbeats back on async with self._pause_heartbeats: - print("restore acquired lock") self._snapshot_running = False self._pause_heartbeats.notify_all() - print("restore released lock") logger.debug("Container: restored") @@ -644,14 +639,12 @@ async def memory_snapshot(self) -> None: # Pause heartbeats since they keep the client connection open which causes the snapshotter to crash async with self._pause_heartbeats: - print("snapshot acquired lock") self._snapshot_running = True self._pause_heartbeats.notify_all() await self._client.stub.ContainerCheckpoint( api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) ) - print("snapshot sent request") self._waiting_for_memory_snapshot = True await self._client._close(forget_credentials=True) From 096a39922af2fc442dc9dde6d3a9c40230175f60 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 18:24:05 +0000 Subject: [PATCH 07/22] Remove old logic --- modal/_container_io_manager.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 6df817a17..6ab0a1b51 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -73,7 +73,6 @@ class _ContainerIOManager: _input_concurrency: Optional[int] _semaphore: Optional[asyncio.Semaphore] _environment_name: str - _waiting_for_memory_snapshot: bool _heartbeat_loop: Optional[asyncio.Task] _pause_heartbeats: Optional[asyncio.Condition] 
_snapshot_running: bool @@ -103,7 +102,6 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client): self._semaphore = None self._environment_name = container_args.environment_name - self._waiting_for_memory_snapshot = False self._heartbeat_loop = None self._pause_heartbeats = asyncio.Condition() @@ -149,12 +147,6 @@ async def _heartbeat_handle_cancellations(self) -> bool: # Return True if a cancellation event was received, in that case # we shouldn't wait too long for another heartbeat - # Don't send heartbeats for tasks waiting to be checkpointed. - # Calling gRPC methods open new connections which block the - # checkpointing process. - if self._waiting_for_memory_snapshot: - return False - request = api_pb2.ContainerHeartbeatRequest(supports_graceful_input_cancellation=True) if self.current_input_id is not None: request.current_input_id = self.current_input_id @@ -630,7 +622,6 @@ async def memory_restore(self) -> None: ) self._client = await _Client.from_env() - self._waiting_for_memory_snapshot = False async def memory_snapshot(self) -> None: """Message server indicating that function is ready to be checkpointed.""" @@ -646,7 +637,6 @@ async def memory_snapshot(self) -> None: api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) ) - self._waiting_for_memory_snapshot = True await self._client._close(forget_credentials=True) logger.debug("Memory snapshot request sent. 
Connection closed.") From bf5400b48cbe62f2ccd94d8b0c0f0d8af52f2cea Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 18:28:00 +0000 Subject: [PATCH 08/22] Renamed cond var --- modal/_container_io_manager.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 6ab0a1b51..f3820f39b 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -75,7 +75,7 @@ class _ContainerIOManager: _environment_name: str _heartbeat_loop: Optional[asyncio.Task] _pause_heartbeats: Optional[asyncio.Condition] - _snapshot_running: bool + _waiting_for_memory_snapshot: bool _is_interactivity_enabled: bool _fetching_inputs: bool @@ -104,6 +104,7 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client): self._environment_name = container_args.environment_name self._heartbeat_loop = None self._pause_heartbeats = asyncio.Condition() + self._waiting_for_memory_snapshot = False self._is_interactivity_enabled = False self._fetching_inputs = True @@ -154,7 +155,7 @@ async def _heartbeat_handle_cancellations(self) -> bool: request.current_input_started_at = self.current_input_started_at async with self._pause_heartbeats: - while self._snapshot_running: + while self._waiting_for_memory_snapshot: await self._pause_heartbeats.wait() # TODO(erikbern): capture exceptions? 
@@ -199,7 +200,7 @@ async def heartbeats(self, enable: bool) -> AsyncGenerator[None, None]: async with TaskContext() as tc: self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop()) t.set_name("heartbeat loop") - self._snapshot_running = enable + self._waiting_for_memory_snapshot = enable try: yield finally: @@ -576,7 +577,7 @@ async def memory_restore(self) -> None: # Turn heartbeats back on async with self._pause_heartbeats: - self._snapshot_running = False + self._waiting_for_memory_snapshot = False self._pause_heartbeats.notify_all() logger.debug("Container: restored") @@ -630,7 +631,7 @@ async def memory_snapshot(self) -> None: # Pause heartbeats since they keep the client connection open which causes the snapshotter to crash async with self._pause_heartbeats: - self._snapshot_running = True + self._waiting_for_memory_snapshot = True self._pause_heartbeats.notify_all() await self._client.stub.ContainerCheckpoint( From 484e24da3b20479dde14ad395e1c9a3ffa24772b Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 19:20:51 +0000 Subject: [PATCH 09/22] Wrote tests --- modal/_container_io_manager.py | 4 ++-- test/container_app_test.py | 27 +++++++++++++++++++++++++++ test/container_test.py | 3 +++ test/live_reload_test.py | 1 + test/supports/functions.py | 5 +++++ 5 files changed, 38 insertions(+), 2 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index f3820f39b..db227856e 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -196,11 +196,11 @@ async def _heartbeat_handle_cancellations(self) -> bool: return False @asynccontextmanager - async def heartbeats(self, enable: bool) -> AsyncGenerator[None, None]: + async def heartbeats(self, disable_init: bool) -> AsyncGenerator[None, None]: async with TaskContext() as tc: self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop()) t.set_name("heartbeat loop") - self._waiting_for_memory_snapshot = enable + 
self._waiting_for_memory_snapshot = disable_init try: yield finally: diff --git a/test/container_app_test.py b/test/container_app_test.py index b4eff53f6..33f9284d5 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -64,6 +64,33 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer): # In-memory Client instance should have update credentials, not old credentials assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored") +@pytest.mark.asyncio +async def test_container_snapshot_restore_heartbeats(container_client, tmpdir, servicer): + # Get a reference to a Client instance in memory + io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client) + restore_path = tmpdir.join("fake-restore-state.json") + # Write out a restore file so that snapshot+restore will complete + restore_path.write_text( + json.dumps( + { + "task_id": "ta-i-am-restored", + "task_secret": "ts-i-am-restored", + "function_id": "fu-i-am-restored", + } + ), + encoding="utf-8", + ) + + with io_manager.heartbeats(True): + with mock.patch.dict( + os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, + ): + with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 0.5): + with mock.patch.object(container_client.stub, 'ContainerHeartbeat') as mock_heartbeat: + io_manager.memory_snapshot() + assert not mock_heartbeat.assert_called_once() + + @pytest.mark.asyncio async def test_container_debug_snapshot(container_client, tmpdir, servicer): diff --git a/test/container_test.py b/test/container_test.py index f69bd749b..d067d5cdf 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -819,6 +819,9 @@ def test_container_heartbeats(servicer): _run_container(servicer, "test.supports.functions", "square") assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in servicer.requests) + _run_container(servicer, "test.supports.functions", 
"snapshotting_square") + assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in servicer.requests) + @skip_github_non_linux def test_cli(servicer): diff --git a/test/live_reload_test.py b/test/live_reload_test.py index 873f9a382..01f427fec 100644 --- a/test/live_reload_test.py +++ b/test/live_reload_test.py @@ -78,3 +78,4 @@ async def test_heartbeats(app_ref, server_url_env, servicer): # Typically [0s, 1s, 2s, 3s], but asyncio.sleep may lag. actual_heartbeats = servicer.app_heartbeats[apps[0]] assert abs(actual_heartbeats - (total_secs + 1)) <= 1 + diff --git a/test/supports/functions.py b/test/supports/functions.py index d2a9a84df..1ef67e2f2 100644 --- a/test/supports/functions.py +++ b/test/supports/functions.py @@ -414,6 +414,11 @@ def f(self, x): return "".join(self._vals) + x +@app.function(enable_memory_snapshot=True) +def snapshotting_square(x): + return x * x + + @app.cls() class EventLoopCls: @enter() From b6fa509ba2c918f36bfbeefd4b3ee9a8cb8c3db5 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 19:42:36 +0000 Subject: [PATCH 10/22] Wrote better tests --- test/container_app_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/container_app_test.py b/test/container_app_test.py index 33f9284d5..2805777d1 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -66,10 +66,8 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer): @pytest.mark.asyncio async def test_container_snapshot_restore_heartbeats(container_client, tmpdir, servicer): - # Get a reference to a Client instance in memory io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client) restore_path = tmpdir.join("fake-restore-state.json") - # Write out a restore file so that snapshot+restore will complete restore_path.write_text( json.dumps( { @@ -81,15 +79,17 @@ async def test_container_snapshot_restore_heartbeats(container_client, tmpdir, s 
encoding="utf-8", ) + # Ensure that heartbeats do not run before the snapshot + # Ensure that heartbeats do run after the snapshot with io_manager.heartbeats(True): with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, ): - with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 0.5): + with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 1): with mock.patch.object(container_client.stub, 'ContainerHeartbeat') as mock_heartbeat: + mock_heartbeat.assert_not_called() io_manager.memory_snapshot() - assert not mock_heartbeat.assert_called_once() - + mock_heartbeat.assert_called_once() @pytest.mark.asyncio From 6071891db1d2a46c06128476a4d9f19861f24b77 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 19:57:24 +0000 Subject: [PATCH 11/22] Try to fix tests --- test/container_app_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/container_app_test.py b/test/container_app_test.py index 2805777d1..9e4f87760 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -4,6 +4,7 @@ import pytest from typing import Dict from unittest import mock +import asyncio from google.protobuf.empty_pb2 import Empty from google.protobuf.message import Message @@ -79,6 +80,8 @@ async def test_container_snapshot_restore_heartbeats(container_client, tmpdir, s encoding="utf-8", ) + asyncio.set_event_loop(asyncio.new_event_loop()) + # Ensure that heartbeats do not run before the snapshot # Ensure that heartbeats do run after the snapshot with io_manager.heartbeats(True): @@ -87,6 +90,8 @@ async def test_container_snapshot_restore_heartbeats(container_client, tmpdir, s ): with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 1): with mock.patch.object(container_client.stub, 'ContainerHeartbeat') as mock_heartbeat: + + await asyncio.sleep(2) mock_heartbeat.assert_not_called() io_manager.memory_snapshot() mock_heartbeat.assert_called_once() From 70c49ff2230a7ab72091b79db500be531666b36c 
Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 20:17:45 +0000 Subject: [PATCH 12/22] Retry --- test/async_utils_test.py | 2 +- test/container_app_test.py | 4 ++-- test/live_reload_test.py | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/test/async_utils_test.py b/test/async_utils_test.py index d7e047b84..cd82029ad 100644 --- a/test/async_utils_test.py +++ b/test/async_utils_test.py @@ -183,7 +183,7 @@ async def my_generator(): assert "list" in caplog.text -@pytest.mark.asyncio +@pytest.mark def test_warn_if_generator_is_not_consumed_sync(caplog): @warn_if_generator_is_not_consumed() def my_generator(): diff --git a/test/container_app_test.py b/test/container_app_test.py index 9e4f87760..6c947b5fc 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -1,10 +1,10 @@ # Copyright Modal Labs 2022 +import asyncio import json import os import pytest from typing import Dict from unittest import mock -import asyncio from google.protobuf.empty_pb2 import Empty from google.protobuf.message import Message @@ -59,7 +59,7 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer): encoding="utf-8", ) with mock.patch.dict( - os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, + os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): io_manager.memory_snapshot() # In-memory Client instance should have update credentials, not old credentials diff --git a/test/live_reload_test.py b/test/live_reload_test.py index 01f427fec..873f9a382 100644 --- a/test/live_reload_test.py +++ b/test/live_reload_test.py @@ -78,4 +78,3 @@ async def test_heartbeats(app_ref, server_url_env, servicer): # Typically [0s, 1s, 2s, 3s], but asyncio.sleep may lag. 
actual_heartbeats = servicer.app_heartbeats[apps[0]] assert abs(actual_heartbeats - (total_secs + 1)) <= 1 - From 5b423f57cde0342c2620932738f1998911bd8c84 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 20:22:30 +0000 Subject: [PATCH 13/22] Undo warning --- test/async_utils_test.py | 2 +- test/container_test.py | 2 +- test/supports/functions.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/async_utils_test.py b/test/async_utils_test.py index cd82029ad..d7e047b84 100644 --- a/test/async_utils_test.py +++ b/test/async_utils_test.py @@ -183,7 +183,7 @@ async def my_generator(): assert "list" in caplog.text -@pytest.mark +@pytest.mark.asyncio def test_warn_if_generator_is_not_consumed_sync(caplog): @warn_if_generator_is_not_consumed() def my_generator(): diff --git a/test/container_test.py b/test/container_test.py index d067d5cdf..a6b2702ec 100644 --- a/test/container_test.py +++ b/test/container_test.py @@ -790,7 +790,7 @@ def test_checkpointing_cls_function(servicer): ret = _run_container( servicer, "test.supports.functions", - "CheckpointingCls.*", + "SnapshottingCls.*", inputs=_get_inputs((("D",), {}), method_name="f"), is_checkpointing_function=True, is_class=True, diff --git a/test/supports/functions.py b/test/supports/functions.py index 1ef67e2f2..4fcc9618d 100644 --- a/test/supports/functions.py +++ b/test/supports/functions.py @@ -393,7 +393,7 @@ def f(self, x): @app.cls(enable_memory_snapshot=True) -class CheckpointingCls: +class SnapshottingCls: def __init__(self): self._vals = [] From 817709fddbfdb47bcfaf25acbfa09dfd6d07a876 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 21:08:48 +0000 Subject: [PATCH 14/22] Await --- test/container_app_test.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/container_app_test.py b/test/container_app_test.py index 6c947b5fc..67e61e39f 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -80,8 +80,6 @@ async def 
test_container_snapshot_restore_heartbeats(container_client, tmpdir, s encoding="utf-8", ) - asyncio.set_event_loop(asyncio.new_event_loop()) - # Ensure that heartbeats do not run before the snapshot # Ensure that heartbeats do run after the snapshot with io_manager.heartbeats(True): @@ -90,10 +88,8 @@ async def test_container_snapshot_restore_heartbeats(container_client, tmpdir, s ): with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 1): with mock.patch.object(container_client.stub, 'ContainerHeartbeat') as mock_heartbeat: - - await asyncio.sleep(2) mock_heartbeat.assert_not_called() - io_manager.memory_snapshot() + await io_manager.memory_snapshot() mock_heartbeat.assert_called_once() From 3259b2b4d7c0ee512a742a82a75309e7bdfc7cca Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 21:42:56 +0000 Subject: [PATCH 15/22] Try using async client --- test/container_app_test.py | 52 ++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/test/container_app_test.py b/test/container_app_test.py index 67e61e39f..51e9b9a89 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -2,6 +2,7 @@ import asyncio import json import os +from modal.client import _Client import pytest from typing import Dict from unittest import mock @@ -10,7 +11,7 @@ from google.protobuf.message import Message from modal import App, interact -from modal._container_io_manager import ContainerIOManager +from modal._container_io_manager import ContainerIOManager, _ContainerIOManager from modal.running_app import RunningApp from modal_proto import api_pb2 @@ -66,31 +67,32 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer): assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored") @pytest.mark.asyncio -async def test_container_snapshot_restore_heartbeats(container_client, tmpdir, servicer): - io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client) - 
restore_path = tmpdir.join("fake-restore-state.json") - restore_path.write_text( - json.dumps( - { - "task_id": "ta-i-am-restored", - "task_secret": "ts-i-am-restored", - "function_id": "fu-i-am-restored", - } - ), - encoding="utf-8", - ) +async def test_container_snapshot_restore_heartbeats(tmpdir, servicer): + async with _Client(servicer.container_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret")) as async_client: + io_manager = _ContainerIOManager(api_pb2.ContainerArguments(), async_client) + restore_path = tmpdir.join("fake-restore-state.json") + restore_path.write_text( + json.dumps( + { + "task_id": "ta-i-am-restored", + "task_secret": "ts-i-am-restored", + "function_id": "fu-i-am-restored", + } + ), + encoding="utf-8", + ) - # Ensure that heartbeats do not run before the snapshot - # Ensure that heartbeats do run after the snapshot - with io_manager.heartbeats(True): - with mock.patch.dict( - os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, - ): - with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 1): - with mock.patch.object(container_client.stub, 'ContainerHeartbeat') as mock_heartbeat: - mock_heartbeat.assert_not_called() - await io_manager.memory_snapshot() - mock_heartbeat.assert_called_once() + # Ensure that heartbeats do not run before the snapshot + # Ensure that heartbeats do run after the snapshot + async with io_manager.heartbeats(True): + with mock.patch.dict( + os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, + ): + with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 1): + with mock.patch.object(async_client.stub, 'ContainerHeartbeat') as mock_heartbeat: + mock_heartbeat.assert_not_called() + await io_manager.memory_snapshot() + mock_heartbeat.assert_called_once() @pytest.mark.asyncio From 3c3209e867cd63846036866b2065891f12f02889 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 21:55:53 +0000 Subject: 
[PATCH 16/22] Use servicer --- test/conftest.py | 2 ++ test/container_app_test.py | 7 +++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 00be2f2ac..946a39264 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -106,6 +106,7 @@ def __init__(self, blob_host, blobs): self.n_queue_heartbeats = 0 self.n_nfs_heartbeats = 0 self.n_vol_heartbeats = 0 + self.n_sent_heartbeats = 0 self.n_mounts = 0 self.n_mount_files = 0 self.mount_contents = {} @@ -464,6 +465,7 @@ async def ContainerHeartbeat(self, stream): if self.container_heartbeat_response: await stream.send_message(self.container_heartbeat_response) self.container_heartbeat_response = None + self.n_sent_heartbeats += 1 else: await stream.send_message(api_pb2.ContainerHeartbeatResponse()) diff --git a/test/container_app_test.py b/test/container_app_test.py index 51e9b9a89..7136f6820 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -89,10 +89,9 @@ async def test_container_snapshot_restore_heartbeats(tmpdir, servicer): os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, ): with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 1): - with mock.patch.object(async_client.stub, 'ContainerHeartbeat') as mock_heartbeat: - mock_heartbeat.assert_not_called() - await io_manager.memory_snapshot() - mock_heartbeat.assert_called_once() + assert servicer.n_sent_heartbeats == 0 + await io_manager.memory_snapshot() + assert servicer.n_sent_heartbeats > 0 @pytest.mark.asyncio From a9070ff0766e2b58af20bf5a54c53aeb4c1a291d Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 22:03:47 +0000 Subject: [PATCH 17/22] Use servicer --- test/conftest.py | 2 -- test/container_app_test.py | 4 ++-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 946a39264..00be2f2ac 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -106,7 +106,6 @@ def 
__init__(self, blob_host, blobs): self.n_queue_heartbeats = 0 self.n_nfs_heartbeats = 0 self.n_vol_heartbeats = 0 - self.n_sent_heartbeats = 0 self.n_mounts = 0 self.n_mount_files = 0 self.mount_contents = {} @@ -465,7 +464,6 @@ async def ContainerHeartbeat(self, stream): if self.container_heartbeat_response: await stream.send_message(self.container_heartbeat_response) self.container_heartbeat_response = None - self.n_sent_heartbeats += 1 else: await stream.send_message(api_pb2.ContainerHeartbeatResponse()) diff --git a/test/container_app_test.py b/test/container_app_test.py index 7136f6820..e4aa859ae 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -89,9 +89,9 @@ async def test_container_snapshot_restore_heartbeats(tmpdir, servicer): os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, ): with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 1): - assert servicer.n_sent_heartbeats == 0 + assert not list(filter(lambda req : isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests)) await io_manager.memory_snapshot() - assert servicer.n_sent_heartbeats > 0 + assert list(filter(lambda req : isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests)) @pytest.mark.asyncio From 153eb0a878311c9fa77956479cad992125500ace Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 22:09:33 +0000 Subject: [PATCH 18/22] Lint --- test/container_app_test.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/test/container_app_test.py b/test/container_app_test.py index e4aa859ae..84595f026 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -1,8 +1,6 @@ # Copyright Modal Labs 2022 -import asyncio import json import os -from modal.client import _Client import pytest from typing import Dict from unittest import mock @@ -12,6 +10,7 @@ from modal import App, interact from modal._container_io_manager import 
ContainerIOManager, _ContainerIOManager +from modal.client import _Client from modal.running_app import RunningApp from modal_proto import api_pb2 @@ -66,9 +65,11 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer): # In-memory Client instance should have update credentials, not old credentials assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored") + @pytest.mark.asyncio async def test_container_snapshot_restore_heartbeats(tmpdir, servicer): - async with _Client(servicer.container_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret")) as async_client: + client = _Client(servicer.container_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret")) + async with client as async_client: io_manager = _ContainerIOManager(api_pb2.ContainerArguments(), async_client) restore_path = tmpdir.join("fake-restore-state.json") restore_path.write_text( @@ -86,12 +87,17 @@ async def test_container_snapshot_restore_heartbeats(tmpdir, servicer): # Ensure that heartbeats do run after the snapshot async with io_manager.heartbeats(True): with mock.patch.dict( - os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, + os.environ, + {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, ): with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 1): - assert not list(filter(lambda req : isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests)) + assert not list( + filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests) + ) await io_manager.memory_snapshot() - assert list(filter(lambda req : isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests)) + assert list( + filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests) + ) @pytest.mark.asyncio From 5bb40448be1e1d11091e3dec15f2e17b58ec3421 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 
22:23:43 +0000 Subject: [PATCH 19/22] Add sleep in test --- test/container_app_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/container_app_test.py b/test/container_app_test.py index 84595f026..9dffaf0f4 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -1,4 +1,5 @@ # Copyright Modal Labs 2022 +import asyncio import json import os import pytest @@ -95,6 +96,7 @@ async def test_container_snapshot_restore_heartbeats(tmpdir, servicer): filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests) ) await io_manager.memory_snapshot() + await asyncio.sleep(1) assert list( filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests) ) From 91b56658799b92929aa7db9498034045397ccf63 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Mon, 15 Jul 2024 22:40:48 +0000 Subject: [PATCH 20/22] Reduce sleep time --- test/container_app_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/container_app_test.py b/test/container_app_test.py index 9dffaf0f4..7d9a86b32 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -91,12 +91,12 @@ async def test_container_snapshot_restore_heartbeats(tmpdir, servicer): os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, ): - with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 1): + with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 0.01): assert not list( filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests) ) await io_manager.memory_snapshot() - await asyncio.sleep(1) + await asyncio.sleep(0.01) assert list( filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests) ) From 74ae84a0fec64c4aa57b05f471a7883977c61816 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Tue, 16 Jul 2024 14:42:19 +0000 Subject: [PATCH 21/22] Address review --- modal/_container_io_manager.py | 25 +++++++----- 
test/container_app_test.py | 74 +++++++++++----------------------- 2 files changed, 38 insertions(+), 61 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index db227856e..9c7c3c8cb 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -74,7 +74,7 @@ class _ContainerIOManager: _semaphore: Optional[asyncio.Semaphore] _environment_name: str _heartbeat_loop: Optional[asyncio.Task] - _pause_heartbeats: Optional[asyncio.Condition] + _heartbeat_condition: asyncio.Condition _waiting_for_memory_snapshot: bool _is_interactivity_enabled: bool @@ -103,7 +103,7 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client): self._semaphore = None self._environment_name = container_args.environment_name self._heartbeat_loop = None - self._pause_heartbeats = asyncio.Condition() + self._heartbeat_condition = asyncio.Condition() self._waiting_for_memory_snapshot = False self._is_interactivity_enabled = False @@ -154,9 +154,11 @@ async def _heartbeat_handle_cancellations(self) -> bool: if self.current_input_started_at is not None: request.current_input_started_at = self.current_input_started_at - async with self._pause_heartbeats: + async with self._heartbeat_condition: + # Continuously wait until `waiting_for_memory_snapshot` is false. More efficient + # than a busy-wait since `.wait()` yields control while self._waiting_for_memory_snapshot: - await self._pause_heartbeats.wait() + await self._heartbeat_condition.wait() # TODO(erikbern): capture exceptions? 
response = await retry_transient_errors( @@ -196,11 +198,11 @@ async def _heartbeat_handle_cancellations(self) -> bool: return False @asynccontextmanager - async def heartbeats(self, disable_init: bool) -> AsyncGenerator[None, None]: + async def heartbeats(self, wait_for_mem_snap: bool) -> AsyncGenerator[None, None]: async with TaskContext() as tc: self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop()) t.set_name("heartbeat loop") - self._waiting_for_memory_snapshot = disable_init + self._waiting_for_memory_snapshot = wait_for_mem_snap try: yield finally: @@ -575,10 +577,11 @@ async def memory_restore(self) -> None: await asyncio.sleep(0.01) continue - # Turn heartbeats back on - async with self._pause_heartbeats: + # Turn heartbeats back on. It is safe to do this here since the Snapshot RPC + # is certainly finished at this point. + async with self._heartbeat_condition: self._waiting_for_memory_snapshot = False - self._pause_heartbeats.notify_all() + self._heartbeat_condition.notify_all() logger.debug("Container: restored") @@ -630,9 +633,9 @@ async def memory_snapshot(self) -> None: logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)") # Pause heartbeats since they keep the client connection open which causes the snapshotter to crash - async with self._pause_heartbeats: + async with self._heartbeat_condition: self._waiting_for_memory_snapshot = True - self._pause_heartbeats.notify_all() + self._heartbeat_condition.notify_all() await self._client.stub.ContainerCheckpoint( api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) diff --git a/test/container_app_test.py b/test/container_app_test.py index 7d9a86b32..15855e9d1 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -19,6 +19,21 @@ def my_f_1(x): pass +def temp_restore_path(tmpdir): + # Write out a restore file so that snapshot+restore will complete + restore_path = tmpdir.join("fake-restore-state.json") + restore_path.write_text( + 
json.dumps( + { + "task_id": "ta-i-am-restored", + "task_secret": "ts-i-am-restored", + "function_id": "fu-i-am-restored", + } + ), + encoding="utf-8", + ) + return restore_path + @pytest.mark.asyncio async def test_container_function_lazily_imported(container_client): @@ -47,18 +62,7 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer): # Get a reference to a Client instance in memory old_client = container_client io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client) - restore_path = tmpdir.join("fake-restore-state.json") - # Write out a restore file so that snapshot+restore will complete - restore_path.write_text( - json.dumps( - { - "task_id": "ta-i-am-restored", - "task_secret": "ts-i-am-restored", - "function_id": "fu-i-am-restored", - } - ), - encoding="utf-8", - ) + restore_path = temp_restore_path(tmpdir) with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): @@ -72,31 +76,22 @@ async def test_container_snapshot_restore_heartbeats(tmpdir, servicer): client = _Client(servicer.container_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret")) async with client as async_client: io_manager = _ContainerIOManager(api_pb2.ContainerArguments(), async_client) - restore_path = tmpdir.join("fake-restore-state.json") - restore_path.write_text( - json.dumps( - { - "task_id": "ta-i-am-restored", - "task_secret": "ts-i-am-restored", - "function_id": "fu-i-am-restored", - } - ), - encoding="utf-8", - ) + restore_path = temp_restore_path(tmpdir) - # Ensure that heartbeats do not run before the snapshot - # Ensure that heartbeats do run after the snapshot + # Ensure that heartbeats only run after the snapshot + heartbeat_interval_secs = 0.01 async with io_manager.heartbeats(True): with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, ): - with 
mock.patch("modal.runner.HEARTBEAT_INTERVAL", 0.01): + with mock.patch("modal.runner.HEARTBEAT_INTERVAL", heartbeat_interval_secs): + await asyncio.sleep(heartbeat_interval_secs*2) assert not list( filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests) ) await io_manager.memory_snapshot() - await asyncio.sleep(0.01) + await asyncio.sleep(heartbeat_interval_secs*2) assert list( filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests) ) @@ -158,7 +153,6 @@ def weird_torch_module(): @pytest.mark.asyncio async def test_container_snapshot_patching(fake_torch_module, container_client, tmpdir, servicer): io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client) - restore_path = tmpdir.join("fake-restore-state.json") # bring fake torch into scope and call the utility fn import torch @@ -166,16 +160,7 @@ async def test_container_snapshot_patching(fake_torch_module, container_client, assert torch.cuda.device_count() == 0 # Write out a restore file so that snapshot+restore will complete - restore_path.write_text( - json.dumps( - { - "task_id": "ta-i-am-restored", - "task_secret": "ts-i-am-restored", - "function_id": "fu-i-am-restored", - } - ), - encoding="utf-8", - ) + restore_path = temp_restore_path(tmpdir) with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): @@ -188,24 +173,13 @@ async def test_container_snapshot_patching(fake_torch_module, container_client, @pytest.mark.asyncio async def test_container_snapshot_patching_err(weird_torch_module, container_client, tmpdir, servicer): io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client) - restore_path = tmpdir.join("fake-restore-state.json") + restore_path = temp_restore_path(tmpdir) # bring weird torch into scope and call the utility fn import torch as trch assert trch.IM_WEIRD == 42 - # Write out a restore file so that 
snapshot+restore will complete - restore_path.write_text( - json.dumps( - { - "task_id": "ta-i-am-restored", - "task_secret": "ts-i-am-restored", - "function_id": "fu-i-am-restored", - } - ), - encoding="utf-8", - ) with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): From c0f5fcbcd934514ca15c9973f81cb4c2794187a1 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Tue, 16 Jul 2024 16:17:03 +0000 Subject: [PATCH 22/22] Run entire restore phase within the lock --- modal/_container_io_manager.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 9c7c3c8cb..ded25a955 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -155,8 +155,10 @@ async def _heartbeat_handle_cancellations(self) -> bool: request.current_input_started_at = self.current_input_started_at async with self._heartbeat_condition: - # Continuously wait until `waiting_for_memory_snapshot` is false. More efficient - # than a busy-wait since `.wait()` yields control + # Continuously wait until `waiting_for_memory_snapshot` is false. + # TODO(matt): Verify that a `while` is necessary over an `if`. Spurious + # wakeups could allow execution to continue despite `_waiting_for_memory_snapshot` + # being true. while self._waiting_for_memory_snapshot: await self._heartbeat_condition.wait() @@ -577,12 +579,6 @@ async def memory_restore(self) -> None: await asyncio.sleep(0.01) continue - # Turn heartbeats back on. It is safe to do this here since the Snapshot RPC - # is certainly finished at this point. - async with self._heartbeat_condition: - self._waiting_for_memory_snapshot = False - self._heartbeat_condition.notify_all() - logger.debug("Container: restored") # Look for state file and create new client with updated credentials. 
@@ -634,17 +630,24 @@ async def memory_snapshot(self) -> None: # Pause heartbeats since they keep the client connection open which causes the snapshotter to crash async with self._heartbeat_condition: + # Notify the heartbeat loop that the snapshot phase has begun in order to + # prevent it from sending heartbeat RPCs self._waiting_for_memory_snapshot = True self._heartbeat_condition.notify_all() - await self._client.stub.ContainerCheckpoint( - api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) - ) + await self._client.stub.ContainerCheckpoint( + api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) + ) + + await self._client._close(forget_credentials=True) - await self._client._close(forget_credentials=True) + logger.debug("Memory snapshot request sent. Connection closed.") + await self.memory_restore() - logger.debug("Memory snapshot request sent. Connection closed.") - await self.memory_restore() + # Turn heartbeats back on. This is safe since the snapshot RPC + # and the restore phase have finished. + self._waiting_for_memory_snapshot = False + self._heartbeat_condition.notify_all() async def volume_commit(self, volume_ids: List[str]) -> None: """