Skip to content

Commit

Permalink
Use asyncio.Condition to ensure mutual exclusion
Browse files Browse the repository at this point in the history
  • Loading branch information
mattnappo committed Jul 12, 2024
1 parent 2cdcd60 commit 626af1c
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class _ContainerIOManager:
_environment_name: str
_waiting_for_memory_snapshot: bool
_heartbeat_loop: Optional[asyncio.Task]
_pause_heartbeats: Optional[asyncio.Event]
_pause_heartbeats: Optional[asyncio.Condition]

_is_interactivity_enabled: bool
_fetching_inputs: bool
Expand Down Expand Up @@ -160,13 +160,14 @@ async def _heartbeat_handle_cancellations(self) -> bool:
if self.current_input_started_at is not None:
request.current_input_started_at = self.current_input_started_at

# Wait until memory snapshotting finishes
await self._pause_heartbeats.wait()
async with self._pause_heartbeats:
await self._pause_heartbeats.wait()

# TODO(erikbern): capture exceptions?
response = await retry_transient_errors(
self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
)
# TODO(erikbern): capture exceptions?
response = await retry_transient_errors(
self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
)
print("SENT HEARTBEAT")

if response.HasField("cancel_input_event"):
# Pause processing of the current input by signaling self a SIGUSR1.
Expand Down Expand Up @@ -204,7 +205,7 @@ async def _heartbeat_handle_cancellations(self) -> bool:
async def heartbeats(self) -> AsyncGenerator[None, None]:
async with TaskContext() as tc:
self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop())
self._pause_heartbeats = asyncio.Event()
self._pause_heartbeats = asyncio.Condition()
t.set_name("heartbeat loop")
try:
yield
Expand Down Expand Up @@ -579,7 +580,6 @@ async def memory_restore(self) -> None:
logger.debug(f"Waiting for restore (elapsed={time.perf_counter() - start:.3f}s)")
await asyncio.sleep(0.01)
continue
self._pause_heartbeats.set() # Resume heartbeats

logger.debug("Container: restored")

Expand Down Expand Up @@ -632,10 +632,10 @@ async def memory_snapshot(self) -> None:
logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)")

# Pause heartbeats, since they keep the client connection open, causing the snapshotter to crash.
self._pause_heartbeats.clear()
await self._client.stub.ContainerCheckpoint(
api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
)
async with self._pause_heartbeats:
await self._client.stub.ContainerCheckpoint(
api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
)

self._waiting_for_memory_snapshot = True
await self._client._close(forget_credentials=True)
Expand Down

0 comments on commit 626af1c

Please sign in to comment.