Skip to content

Commit

Permalink
Fixed unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mattnappo committed Jul 12, 2024
1 parent 626af1c commit 7228652
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
4 changes: 1 addition & 3 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,10 @@ async def _heartbeat_handle_cancellations(self) -> bool:

async with self._pause_heartbeats:
await self._pause_heartbeats.wait()

# TODO(erikbern): capture exceptions?
response = await retry_transient_errors(
self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
)
print("SENT HEARTBEAT")

if response.HasField("cancel_input_event"):
# Pause processing of the current input by signaling self a SIGUSR1.
Expand Down Expand Up @@ -631,7 +629,7 @@ async def memory_snapshot(self) -> None:
if self.checkpoint_id:
logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)")

# Pause heartbeats, since they keep the client connection open, causing the snapshotter to crash.
# Pause heartbeats since they keep the client connection open which causes the snapshotter to crash
async with self._pause_heartbeats:
await self._client.stub.ContainerCheckpoint(
api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
Expand Down
24 changes: 14 additions & 10 deletions test/container_app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,12 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer):
encoding="utf-8",
)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr},
):
io_manager.memory_snapshot()
# In-memory Client instance should have update credentials, not old credentials
assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored")
with io_manager.heartbeats():
io_manager.memory_snapshot()
# In-memory Client instance should have update credentials, not old credentials
assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored")


@pytest.mark.asyncio
Expand All @@ -82,8 +83,9 @@ async def test_container_debug_snapshot(container_client, tmpdir, servicer):
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
io_manager.memory_snapshot()
test_breakpoint.assert_called_once()
with io_manager.heartbeats():
io_manager.memory_snapshot()
test_breakpoint.assert_called_once()


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -142,10 +144,11 @@ async def test_container_snapshot_patching(fake_torch_module, container_client,
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
io_manager.memory_snapshot()
import torch
with io_manager.heartbeats():
io_manager.memory_snapshot()
import torch

assert torch.cuda.device_count() == 2
assert torch.cuda.device_count() == 2


@pytest.mark.asyncio
Expand All @@ -172,7 +175,8 @@ async def test_container_snapshot_patching_err(weird_torch_module, container_cli
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
io_manager.memory_snapshot() # should not crash
with io_manager.heartbeats():
io_manager.memory_snapshot() # should not crash


def test_interact(container_client, servicer):
Expand Down

0 comments on commit 7228652

Please sign in to comment.