Skip to content

Commit

Permalink
[MOD-3251] Stop heartbeats before sending the container snapshot RPC (#…
Browse files Browse the repository at this point in the history
…2004)

* Stop heartbeats before snapshotting

* Use asyncio.Condition to ensure mutual exclusion

* Use async client

* Use servicer

* Run entire restore phase within the lock

---------

Co-authored-by: Jonathon Belotti <jonathon@modal.com>
  • Loading branch information
mattnappo and thundergolfer committed Jul 16, 2024
1 parent 948133f commit aebfc72
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 60 deletions.
3 changes: 2 additions & 1 deletion modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):

_client: _Client = synchronizer._translate_in(client) # TODO(erikbern): ugly

with container_io_manager.heartbeats(), UserCodeEventLoop() as event_loop:
with container_io_manager.heartbeats(function_def.is_checkpointing_function), UserCodeEventLoop() as event_loop:
# If this is a serialized function, fetch the definition from the server
if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
ser_cls, ser_fun = container_io_manager.get_serialized_function()
Expand Down Expand Up @@ -779,6 +779,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
if function_def.is_checkpointing_function:
container_io_manager.memory_snapshot()


# Install hooks for interactive functions.
if function_def.pty_info.pty_type != api_pb2.PTYInfo.PTY_TYPE_UNSPECIFIED:

Expand Down
57 changes: 36 additions & 21 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ class _ContainerIOManager:
_input_concurrency: Optional[int]
_semaphore: Optional[asyncio.Semaphore]
_environment_name: str
_waiting_for_memory_snapshot: bool
_heartbeat_loop: Optional[asyncio.Task]
_heartbeat_condition: asyncio.Condition
_waiting_for_memory_snapshot: bool

_is_interactivity_enabled: bool
_fetching_inputs: bool
Expand All @@ -101,8 +102,9 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):

self._semaphore = None
self._environment_name = container_args.environment_name
self._waiting_for_memory_snapshot = False
self._heartbeat_loop = None
self._heartbeat_condition = asyncio.Condition()
self._waiting_for_memory_snapshot = False

self._is_interactivity_enabled = False
self._fetching_inputs = True
Expand Down Expand Up @@ -146,22 +148,24 @@ async def _heartbeat_handle_cancellations(self) -> bool:
# Return True if a cancellation event was received, in that case
# we shouldn't wait too long for another heartbeat

# Don't send heartbeats for tasks waiting to be checkpointed.
# Calling gRPC methods open new connections which block the
# checkpointing process.
if self._waiting_for_memory_snapshot:
return False

request = api_pb2.ContainerHeartbeatRequest(supports_graceful_input_cancellation=True)
if self.current_input_id is not None:
request.current_input_id = self.current_input_id
if self.current_input_started_at is not None:
request.current_input_started_at = self.current_input_started_at

# TODO(erikbern): capture exceptions?
response = await retry_transient_errors(
self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
)
async with self._heartbeat_condition:
# Continuously wait until `waiting_for_memory_snapshot` is false.
# TODO(matt): Verify that a `while` is necessary over an `if`. Spurious
# wakeups could allow execution to continue despite `_waiting_for_memory_snapshot`
# being true.
while self._waiting_for_memory_snapshot:
await self._heartbeat_condition.wait()

# TODO(erikbern): capture exceptions?
response = await retry_transient_errors(
self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
)

if response.HasField("cancel_input_event"):
# Pause processing of the current input by signaling self a SIGUSR1.
Expand Down Expand Up @@ -196,10 +200,11 @@ async def _heartbeat_handle_cancellations(self) -> bool:
return False

@asynccontextmanager
async def heartbeats(self) -> AsyncGenerator[None, None]:
async def heartbeats(self, wait_for_mem_snap: bool) -> AsyncGenerator[None, None]:
async with TaskContext() as tc:
self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop())
t.set_name("heartbeat loop")
self._waiting_for_memory_snapshot = wait_for_mem_snap
try:
yield
finally:
Expand Down Expand Up @@ -617,22 +622,32 @@ async def memory_restore(self) -> None:
)

self._client = await _Client.from_env()
self._waiting_for_memory_snapshot = False

async def memory_snapshot(self) -> None:
"""Message server indicating that function is ready to be checkpointed."""
if self.checkpoint_id:
logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)")

await self._client.stub.ContainerCheckpoint(
api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
)
# Pause heartbeats since they keep the client connection open which causes the snapshotter to crash
async with self._heartbeat_condition:
# Notify the heartbeat loop that the snapshot phase has begun in order to
# prevent it from sending heartbeat RPCs
self._waiting_for_memory_snapshot = True
self._heartbeat_condition.notify_all()

await self._client.stub.ContainerCheckpoint(
api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
)

await self._client._close(forget_credentials=True)

self._waiting_for_memory_snapshot = True
await self._client._close(forget_credentials=True)
logger.debug("Memory snapshot request sent. Connection closed.")
await self.memory_restore()

logger.debug("Memory snapshot request sent. Connection closed.")
await self.memory_restore()
# Turn heartbeats back on. This is safe since the snapshot RPC
# and the restore phase has finished.
self._waiting_for_memory_snapshot = False
self._heartbeat_condition.notify_all()

async def volume_commit(self, volume_ids: List[str]) -> None:
"""
Expand Down
83 changes: 47 additions & 36 deletions test/container_app_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright Modal Labs 2022
import asyncio
import json
import os
import pytest
Expand All @@ -9,14 +10,30 @@
from google.protobuf.message import Message

from modal import App, interact
from modal._container_io_manager import ContainerIOManager
from modal._container_io_manager import ContainerIOManager, _ContainerIOManager
from modal.client import _Client
from modal.running_app import RunningApp
from modal_proto import api_pb2


def my_f_1(x):
pass

def temp_restore_path(tmpdir):
# Write out a restore file so that snapshot+restore will complete
restore_path = tmpdir.join("fake-restore-state.json")
restore_path.write_text(
json.dumps(
{
"task_id": "ta-i-am-restored",
"task_secret": "ts-i-am-restored",
"function_id": "fu-i-am-restored",
}
),
encoding="utf-8",
)
return restore_path


@pytest.mark.asyncio
async def test_container_function_lazily_imported(container_client):
Expand Down Expand Up @@ -45,18 +62,7 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer):
# Get a reference to a Client instance in memory
old_client = container_client
io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = tmpdir.join("fake-restore-state.json")
# Write out a restore file so that snapshot+restore will complete
restore_path.write_text(
json.dumps(
{
"task_id": "ta-i-am-restored",
"task_secret": "ts-i-am-restored",
"function_id": "fu-i-am-restored",
}
),
encoding="utf-8",
)
restore_path = temp_restore_path(tmpdir)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
Expand All @@ -65,6 +71,32 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer):
assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored")


@pytest.mark.asyncio
async def test_container_snapshot_restore_heartbeats(tmpdir, servicer):
client = _Client(servicer.container_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret"))
async with client as async_client:
io_manager = _ContainerIOManager(api_pb2.ContainerArguments(), async_client)
restore_path = temp_restore_path(tmpdir)

# Ensure that heartbeats only run after the snapshot
heartbeat_interval_secs = 0.01
async with io_manager.heartbeats(True):
with mock.patch.dict(
os.environ,
{"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr},
):
with mock.patch("modal.runner.HEARTBEAT_INTERVAL", heartbeat_interval_secs):
await asyncio.sleep(heartbeat_interval_secs*2)
assert not list(
filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests)
)
await io_manager.memory_snapshot()
await asyncio.sleep(heartbeat_interval_secs*2)
assert list(
filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests)
)


@pytest.mark.asyncio
async def test_container_debug_snapshot(container_client, tmpdir, servicer):
# Get an IO manager, where restore takes place
Expand Down Expand Up @@ -121,24 +153,14 @@ def weird_torch_module():
@pytest.mark.asyncio
async def test_container_snapshot_patching(fake_torch_module, container_client, tmpdir, servicer):
io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = tmpdir.join("fake-restore-state.json")

# bring fake torch into scope and call the utility fn
import torch

assert torch.cuda.device_count() == 0

# Write out a restore file so that snapshot+restore will complete
restore_path.write_text(
json.dumps(
{
"task_id": "ta-i-am-restored",
"task_secret": "ts-i-am-restored",
"function_id": "fu-i-am-restored",
}
),
encoding="utf-8",
)
restore_path = temp_restore_path(tmpdir)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
Expand All @@ -151,24 +173,13 @@ async def test_container_snapshot_patching(fake_torch_module, container_client,
@pytest.mark.asyncio
async def test_container_snapshot_patching_err(weird_torch_module, container_client, tmpdir, servicer):
io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = tmpdir.join("fake-restore-state.json")
restore_path = temp_restore_path(tmpdir)

# bring weird torch into scope and call the utility fn
import torch as trch

assert trch.IM_WEIRD == 42

# Write out a restore file so that snapshot+restore will complete
restore_path.write_text(
json.dumps(
{
"task_id": "ta-i-am-restored",
"task_secret": "ts-i-am-restored",
"function_id": "fu-i-am-restored",
}
),
encoding="utf-8",
)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
Expand Down
5 changes: 4 additions & 1 deletion test/container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,7 @@ def test_checkpointing_cls_function(servicer):
ret = _run_container(
servicer,
"test.supports.functions",
"CheckpointingCls.*",
"SnapshottingCls.*",
inputs=_get_inputs((("D",), {}), method_name="f"),
is_checkpointing_function=True,
is_class=True,
Expand Down Expand Up @@ -819,6 +819,9 @@ def test_container_heartbeats(servicer):
_run_container(servicer, "test.supports.functions", "square")
assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in servicer.requests)

_run_container(servicer, "test.supports.functions", "snapshotting_square")
assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in servicer.requests)


@skip_github_non_linux
def test_cli(servicer):
Expand Down
7 changes: 6 additions & 1 deletion test/supports/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def f(self, x):


@app.cls(enable_memory_snapshot=True)
class CheckpointingCls:
class SnapshottingCls:
def __init__(self):
self._vals = []

Expand All @@ -414,6 +414,11 @@ def f(self, x):
return "".join(self._vals) + x


@app.function(enable_memory_snapshot=True)
def snapshotting_square(x):
return x * x


@app.cls()
class EventLoopCls:
@enter()
Expand Down

0 comments on commit aebfc72

Please sign in to comment.