From 626af1c587865612d652a41a04fb1e2d0d9f4e5f Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Fri, 12 Jul 2024 18:24:02 +0000 Subject: [PATCH] Use asyncio.Condition to ensure mutual exclusion --- modal/_container_io_manager.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 2048dc225..71f4473ff 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -75,7 +75,7 @@ class _ContainerIOManager: _environment_name: str _waiting_for_memory_snapshot: bool _heartbeat_loop: Optional[asyncio.Task] - _pause_heartbeats: Optional[asyncio.Event] + _pause_heartbeats: Optional[asyncio.Condition] _is_interactivity_enabled: bool _fetching_inputs: bool @@ -160,13 +160,14 @@ async def _heartbeat_handle_cancellations(self) -> bool: if self.current_input_started_at is not None: request.current_input_started_at = self.current_input_started_at - # Wait until memory snapshotting finishes - await self._pause_heartbeats.wait() + async with self._pause_heartbeats: + await self._pause_heartbeats.wait() - # TODO(erikbern): capture exceptions? - response = await retry_transient_errors( - self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT - ) + # TODO(erikbern): capture exceptions? + response = await retry_transient_errors( + self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT + ) + print("SENT HEARTBEAT") if response.HasField("cancel_input_event"): # Pause processing of the current input by signaling self a SIGUSR1. @@ -204,7 +205,7 @@ async def _heartbeat_handle_cancellations(self) -> bool: async def heartbeats(self) -> AsyncGenerator[None, None]: async with TaskContext() as tc: self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop()) - self._pause_heartbeats = asyncio.Event() + self._pause_heartbeats = asyncio.Condition() t.set_name("heartbeat loop") try: yield @@ -579,7 +580,6 @@ async def memory_restore(self) -> None: logger.debug(f"Waiting for restore (elapsed={time.perf_counter() - start:.3f}s)") await asyncio.sleep(0.01) continue - self._pause_heartbeats.set() # Resume heartbeats logger.debug("Container: restored") @@ -632,10 +632,10 @@ async def memory_snapshot(self) -> None: logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)") # Pause heartbeats, since they keep the client connection open, causing the snapshotter to crash. - self._pause_heartbeats.clear() - await self._client.stub.ContainerCheckpoint( - api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) - ) + async with self._pause_heartbeats: + await self._client.stub.ContainerCheckpoint( + api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) + ) self._waiting_for_memory_snapshot = True await self._client._close(forget_credentials=True)