[MOD-3251] Stop heartbeats before sending the container snapshot RPC #2004

Merged 28 commits into main from matt/gvisor-flake-fix on Jul 16, 2024

Commits (28)
f3c04c6
Stop heartbeats before snapshotting
mattnappo Jul 12, 2024
41adde8
Merge branch 'main' into matt/gvisor-flake-fix
thundergolfer Jul 12, 2024
4cceb72
Added asyncio.Event for pausing heartbeats
mattnappo Jul 12, 2024
2cdcd60
Merge branch 'matt/gvisor-flake-fix' of github.com:modal-labs/modal-c…
mattnappo Jul 12, 2024
626af1c
Use asyncio.Condition to ensure mutual exclusion
mattnappo Jul 12, 2024
7228652
Fixed unit tests
mattnappo Jul 12, 2024
b0bef65
Merge branch 'main' of github.com:modal-labs/modal-client into matt/g…
mattnappo Jul 15, 2024
86b8fd5
Fixed bug, but now its slow
mattnappo Jul 15, 2024
4c6ed72
Fixed bottleneck
mattnappo Jul 15, 2024
096a399
Remove old logic
mattnappo Jul 15, 2024
bf5400b
Renamed cond var
mattnappo Jul 15, 2024
484e24d
Wrote tests
mattnappo Jul 15, 2024
b6fa509
Wrote better tests
mattnappo Jul 15, 2024
6071891
Try to fix tests
mattnappo Jul 15, 2024
70c49ff
Retry
mattnappo Jul 15, 2024
5b423f5
Undo warning
mattnappo Jul 15, 2024
817709f
Await
mattnappo Jul 15, 2024
3259b2b
Try using async client
mattnappo Jul 15, 2024
3c3209e
Use servicer
mattnappo Jul 15, 2024
a9070ff
Use servicer
mattnappo Jul 15, 2024
153eb0a
Lint
mattnappo Jul 15, 2024
5bb4044
Add sleep in test
mattnappo Jul 15, 2024
91b5665
Reduce sleep time
mattnappo Jul 15, 2024
7e7d692
Merge branch 'main' into matt/gvisor-flake-fix
thundergolfer Jul 16, 2024
74ae84a
Address review
mattnappo Jul 16, 2024
7bbae23
Merge branch 'matt/gvisor-flake-fix' of github.com:modal-labs/modal-c…
mattnappo Jul 16, 2024
c0f5fcb
Run entire restore phase within the lock
mattnappo Jul 16, 2024
87a9036
Merge branch 'main' into matt/gvisor-flake-fix
thundergolfer Jul 16, 2024
3 changes: 2 additions & 1 deletion modal/_container_entrypoint.py
@@ -699,7 +699,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):

_client: _Client = synchronizer._translate_in(client) # TODO(erikbern): ugly

with container_io_manager.heartbeats(), UserCodeEventLoop() as event_loop:
with container_io_manager.heartbeats(function_def.is_checkpointing_function), UserCodeEventLoop() as event_loop:
# If this is a serialized function, fetch the definition from the server
if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
ser_cls, ser_fun = container_io_manager.get_serialized_function()
@@ -779,6 +779,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
if function_def.is_checkpointing_function:
container_io_manager.memory_snapshot()


# Install hooks for interactive functions.
if function_def.pty_info.pty_type != api_pb2.PTYInfo.PTY_TYPE_UNSPECIFIED:

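In short, the entrypoint now passes `function_def.is_checkpointing_function` into `container_io_manager.heartbeats(...)`, so checkpointing functions enter the heartbeat context with heartbeats already suppressed; heartbeats are only re-enabled once `memory_snapshot()` has finished the snapshot and restore. A standalone sketch of the pause/resume mechanism follows the `modal/_container_io_manager.py` diff below.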
57 changes: 36 additions & 21 deletions modal/_container_io_manager.py
@@ -73,8 +73,9 @@ class _ContainerIOManager:
_input_concurrency: Optional[int]
_semaphore: Optional[asyncio.Semaphore]
_environment_name: str
_waiting_for_memory_snapshot: bool
_heartbeat_loop: Optional[asyncio.Task]
_heartbeat_condition: asyncio.Condition
_waiting_for_memory_snapshot: bool

_is_interactivity_enabled: bool
_fetching_inputs: bool
@@ -101,8 +102,9 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):

self._semaphore = None
self._environment_name = container_args.environment_name
self._waiting_for_memory_snapshot = False
self._heartbeat_loop = None
self._heartbeat_condition = asyncio.Condition()
self._waiting_for_memory_snapshot = False

self._is_interactivity_enabled = False
self._fetching_inputs = True
@@ -146,22 +148,24 @@ async def _heartbeat_handle_cancellations(self) -> bool:
# Return True if a cancellation event was received, in that case
# we shouldn't wait too long for another heartbeat

# Don't send heartbeats for tasks waiting to be checkpointed.
# Calling gRPC methods open new connections which block the
# checkpointing process.
if self._waiting_for_memory_snapshot:
return False

request = api_pb2.ContainerHeartbeatRequest(supports_graceful_input_cancellation=True)
if self.current_input_id is not None:
request.current_input_id = self.current_input_id
if self.current_input_started_at is not None:
request.current_input_started_at = self.current_input_started_at

# TODO(erikbern): capture exceptions?
response = await retry_transient_errors(
self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
)
async with self._heartbeat_condition:
# Continuously wait until `waiting_for_memory_snapshot` is false.
# TODO(matt): Verify that a `while` is necessary over an `if`. Spurious
# wakeups could allow execution to continue despite `_waiting_for_memory_snapshot`
# being true.
while self._waiting_for_memory_snapshot:
await self._heartbeat_condition.wait()

# TODO(erikbern): capture exceptions?
response = await retry_transient_errors(
self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
)

if response.HasField("cancel_input_event"):
# Pause processing of the current input by signaling self a SIGUSR1.
@@ -196,10 +200,11 @@ async def _heartbeat_handle_cancellations(self) -> bool:
return False

@asynccontextmanager
async def heartbeats(self) -> AsyncGenerator[None, None]:
async def heartbeats(self, wait_for_mem_snap: bool) -> AsyncGenerator[None, None]:
async with TaskContext() as tc:
self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop())
t.set_name("heartbeat loop")
self._waiting_for_memory_snapshot = wait_for_mem_snap
try:
yield
finally:
@@ -617,22 +622,32 @@ async def memory_restore(self) -> None:
)

self._client = await _Client.from_env()
self._waiting_for_memory_snapshot = False

async def memory_snapshot(self) -> None:
"""Message server indicating that function is ready to be checkpointed."""
if self.checkpoint_id:
logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)")

await self._client.stub.ContainerCheckpoint(
api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
)
# Pause heartbeats since they keep the client connection open which causes the snapshotter to crash
async with self._heartbeat_condition:
# Notify the heartbeat loop that the snapshot phase has begun in order to
# prevent it from sending heartbeat RPCs
self._waiting_for_memory_snapshot = True
self._heartbeat_condition.notify_all()

await self._client.stub.ContainerCheckpoint(
api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
)

await self._client._close(forget_credentials=True)

self._waiting_for_memory_snapshot = True
await self._client._close(forget_credentials=True)
logger.debug("Memory snapshot request sent. Connection closed.")
await self.memory_restore()

logger.debug("Memory snapshot request sent. Connection closed.")
await self.memory_restore()
# Turn heartbeats back on. This is safe since the snapshot RPC
# and the restore phase has finished.
self._waiting_for_memory_snapshot = False
self._heartbeat_condition.notify_all()

async def volume_commit(self, volume_ids: List[str]) -> None:
"""
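To make the synchronization above easier to follow, here is a minimal, self-contained sketch of the pattern (class names, sleeps, and prints are illustrative stand-ins, not the real Modal client implementation): a heartbeat loop waits on an `asyncio.Condition` while a "waiting for snapshot" flag is set, the `heartbeats()` entry point can start with that flag already set, and the snapshot routine holds the condition's lock across the whole snapshot-and-restore phase before clearing the flag and notifying.

```python
import asyncio
from contextlib import asynccontextmanager


class IOManagerSketch:
    """Illustrative sketch of the pause/resume pattern in this PR."""

    def __init__(self) -> None:
        self._heartbeat_condition = asyncio.Condition()
        self._waiting_for_memory_snapshot = False

    async def _heartbeat_loop(self) -> None:
        while True:
            async with self._heartbeat_condition:
                # A `while` (not `if`) guards against spurious wakeups.
                while self._waiting_for_memory_snapshot:
                    await self._heartbeat_condition.wait()
                print("ContainerHeartbeat")  # stand-in for the heartbeat RPC
            await asyncio.sleep(0.05)

    @asynccontextmanager
    async def heartbeats(self, wait_for_mem_snap: bool):
        # For checkpointing functions the loop starts suppressed, so no
        # heartbeat can be sent before the snapshot completes.
        self._waiting_for_memory_snapshot = wait_for_mem_snap
        task = asyncio.create_task(self._heartbeat_loop())
        try:
            yield
        finally:
            task.cancel()

    async def memory_snapshot(self) -> None:
        # Hold the condition's lock across the whole snapshot + restore phase
        # so the heartbeat loop cannot send a request in between.
        async with self._heartbeat_condition:
            self._waiting_for_memory_snapshot = True
            await asyncio.sleep(0.2)  # stand-in for ContainerCheckpoint + restore
            self._waiting_for_memory_snapshot = False
            self._heartbeat_condition.notify_all()


async def main() -> None:
    mgr = IOManagerSketch()
    async with mgr.heartbeats(wait_for_mem_snap=True):
        await asyncio.sleep(0.2)     # no heartbeats printed yet
        await mgr.memory_snapshot()  # pauses, snapshots, then resumes
        await asyncio.sleep(0.2)     # heartbeats now appear


if __name__ == "__main__":
    asyncio.run(main())
```

Running the sketch prints no heartbeats until `memory_snapshot()` has finished, which is the same ordering the new `test_container_snapshot_restore_heartbeats` test below asserts against the servicer's request log.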
83 changes: 47 additions & 36 deletions test/container_app_test.py
@@ -1,4 +1,5 @@
# Copyright Modal Labs 2022
import asyncio
import json
import os
import pytest
@@ -9,14 +10,30 @@
from google.protobuf.message import Message

from modal import App, interact
from modal._container_io_manager import ContainerIOManager
from modal._container_io_manager import ContainerIOManager, _ContainerIOManager
from modal.client import _Client
from modal.running_app import RunningApp
from modal_proto import api_pb2


def my_f_1(x):
pass

def temp_restore_path(tmpdir):
# Write out a restore file so that snapshot+restore will complete
restore_path = tmpdir.join("fake-restore-state.json")
restore_path.write_text(
json.dumps(
{
"task_id": "ta-i-am-restored",
"task_secret": "ts-i-am-restored",
"function_id": "fu-i-am-restored",
}
),
encoding="utf-8",
)
return restore_path


@pytest.mark.asyncio
async def test_container_function_lazily_imported(container_client):
@@ -45,18 +62,7 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer):
# Get a reference to a Client instance in memory
old_client = container_client
io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = tmpdir.join("fake-restore-state.json")
# Write out a restore file so that snapshot+restore will complete
restore_path.write_text(
json.dumps(
{
"task_id": "ta-i-am-restored",
"task_secret": "ts-i-am-restored",
"function_id": "fu-i-am-restored",
}
),
encoding="utf-8",
)
restore_path = temp_restore_path(tmpdir)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
@@ -65,6 +71,32 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer):
assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored")


@pytest.mark.asyncio
async def test_container_snapshot_restore_heartbeats(tmpdir, servicer):
client = _Client(servicer.container_addr, api_pb2.CLIENT_TYPE_CONTAINER, ("ta-123", "task-secret"))
async with client as async_client:
io_manager = _ContainerIOManager(api_pb2.ContainerArguments(), async_client)
restore_path = temp_restore_path(tmpdir)

# Ensure that heartbeats only run after the snapshot
heartbeat_interval_secs = 0.01
async with io_manager.heartbeats(True):
with mock.patch.dict(
os.environ,
{"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr},
):
with mock.patch("modal.runner.HEARTBEAT_INTERVAL", heartbeat_interval_secs):
await asyncio.sleep(heartbeat_interval_secs*2)
assert not list(
filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests)
)
await io_manager.memory_snapshot()
await asyncio.sleep(heartbeat_interval_secs*2)
assert list(
filter(lambda req: isinstance(req, api_pb2.ContainerHeartbeatRequest), servicer.requests)
)


@pytest.mark.asyncio
async def test_container_debug_snapshot(container_client, tmpdir, servicer):
# Get an IO manager, where restore takes place
@@ -121,24 +153,14 @@ def weird_torch_module():
@pytest.mark.asyncio
async def test_container_snapshot_patching(fake_torch_module, container_client, tmpdir, servicer):
io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = tmpdir.join("fake-restore-state.json")

# bring fake torch into scope and call the utility fn
import torch

assert torch.cuda.device_count() == 0

# Write out a restore file so that snapshot+restore will complete
restore_path.write_text(
json.dumps(
{
"task_id": "ta-i-am-restored",
"task_secret": "ts-i-am-restored",
"function_id": "fu-i-am-restored",
}
),
encoding="utf-8",
)
restore_path = temp_restore_path(tmpdir)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
@@ -151,24 +173,13 @@ async def test_container_snapshot_patching(fake_torch_module, container_client,
@pytest.mark.asyncio
async def test_container_snapshot_patching_err(weird_torch_module, container_client, tmpdir, servicer):
io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = tmpdir.join("fake-restore-state.json")
restore_path = temp_restore_path(tmpdir)

# bring weird torch into scope and call the utility fn
import torch as trch

assert trch.IM_WEIRD == 42

# Write out a restore file so that snapshot+restore will complete
restore_path.write_text(
json.dumps(
{
"task_id": "ta-i-am-restored",
"task_secret": "ts-i-am-restored",
"function_id": "fu-i-am-restored",
}
),
encoding="utf-8",
)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
5 changes: 4 additions & 1 deletion test/container_test.py
@@ -790,7 +790,7 @@ def test_checkpointing_cls_function(servicer):
ret = _run_container(
servicer,
"test.supports.functions",
"CheckpointingCls.*",
"SnapshottingCls.*",
inputs=_get_inputs((("D",), {}), method_name="f"),
is_checkpointing_function=True,
is_class=True,
@@ -819,6 +819,9 @@ def test_container_heartbeats(servicer):
_run_container(servicer, "test.supports.functions", "square")
assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in servicer.requests)

_run_container(servicer, "test.supports.functions", "snapshotting_square")
assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in servicer.requests)


@skip_github_non_linux
def test_cli(servicer):
7 changes: 6 additions & 1 deletion test/supports/functions.py
@@ -393,7 +393,7 @@ def f(self, x):


@app.cls(enable_memory_snapshot=True)
class CheckpointingCls:
class SnapshottingCls:
def __init__(self):
self._vals = []

@@ -414,6 +414,11 @@ def f(self, x):
return "".join(self._vals) + x


@app.function(enable_memory_snapshot=True)
def snapshotting_square(x):
return x * x


@app.cls()
class EventLoopCls:
@enter()