From 72286520563faf4e078c45457589b0b71edb7550 Mon Sep 17 00:00:00 2001 From: Matt Nappo Date: Fri, 12 Jul 2024 19:05:02 +0000 Subject: [PATCH] Fixed unit tests --- modal/_container_io_manager.py | 4 +--- test/container_app_test.py | 24 ++++++++++++++---------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py index 71f4473ff..08096f2ee 100644 --- a/modal/_container_io_manager.py +++ b/modal/_container_io_manager.py @@ -162,12 +162,10 @@ async def _heartbeat_handle_cancellations(self) -> bool: async with self._pause_heartbeats: await self._pause_heartbeats.wait() - # TODO(erikbern): capture exceptions? response = await retry_transient_errors( self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT ) - print("SENT HEARTBEAT") if response.HasField("cancel_input_event"): # Pause processing of the current input by signaling self a SIGUSR1. @@ -631,7 +629,7 @@ async def memory_snapshot(self) -> None: if self.checkpoint_id: logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)") - # Pause heartbeats, since they keep the client connection open, causing the snapshotter to crash. + # Pause heartbeats since they keep the client connection open which causes the snapshotter to crash async with self._pause_heartbeats: await self._client.stub.ContainerCheckpoint( api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id) diff --git a/test/container_app_test.py b/test/container_app_test.py index 90f753e02..345686221 100644 --- a/test/container_app_test.py +++ b/test/container_app_test.py @@ -58,11 +58,12 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer): encoding="utf-8", ) with mock.patch.dict( - os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} + os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}, ): - io_manager.memory_snapshot() - # In-memory Client instance should have update credentials, not old credentials - assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored") + with io_manager.heartbeats(): + io_manager.memory_snapshot() + # In-memory Client instance should have update credentials, not old credentials + assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored") @pytest.mark.asyncio @@ -82,8 +83,9 @@ async def test_container_debug_snapshot(container_client, tmpdir, servicer): with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): - io_manager.memory_snapshot() - test_breakpoint.assert_called_once() + with io_manager.heartbeats(): + io_manager.memory_snapshot() + test_breakpoint.assert_called_once() @pytest.fixture(scope="function") @@ -142,10 +144,11 @@ async def test_container_snapshot_patching(fake_torch_module, container_client, with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): - io_manager.memory_snapshot() - import torch + with io_manager.heartbeats(): + io_manager.memory_snapshot() + import torch - assert torch.cuda.device_count() == 2 + assert torch.cuda.device_count() == 2 @pytest.mark.asyncio @@ -172,7 +175,8 @@ async def test_container_snapshot_patching_err(weird_torch_module, container_cli with mock.patch.dict( os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr} ): - io_manager.memory_snapshot() # should not crash + with io_manager.heartbeats(): + io_manager.memory_snapshot() # should not crash def test_interact(container_client, servicer):