Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MOD-3251] Stop heartbeats before sending the container snapshot RPC #2004

Merged
merged 28 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
f3c04c6
Stop heartbeats before snapshotting
mattnappo Jul 12, 2024
41adde8
Merge branch 'main' into matt/gvisor-flake-fix
thundergolfer Jul 12, 2024
4cceb72
Added asyncio.Event for pausing heartbeats
mattnappo Jul 12, 2024
2cdcd60
Merge branch 'matt/gvisor-flake-fix' of github.com:modal-labs/modal-c…
mattnappo Jul 12, 2024
626af1c
Use asyncio.Condition to ensure mutual exclusion
mattnappo Jul 12, 2024
7228652
Fixed unit tests
mattnappo Jul 12, 2024
b0bef65
Merge branch 'main' of github.com:modal-labs/modal-client into matt/g…
mattnappo Jul 15, 2024
86b8fd5
Fixed bug, but now its slow
mattnappo Jul 15, 2024
4c6ed72
Fixed bottleneck
mattnappo Jul 15, 2024
096a399
Remove old logic
mattnappo Jul 15, 2024
bf5400b
Renamed cond var
mattnappo Jul 15, 2024
484e24d
Wrote tests
mattnappo Jul 15, 2024
b6fa509
Wrote better tests
mattnappo Jul 15, 2024
6071891
Try to fix tests
mattnappo Jul 15, 2024
70c49ff
Retry
mattnappo Jul 15, 2024
5b423f5
Undo warning
mattnappo Jul 15, 2024
817709f
Await
mattnappo Jul 15, 2024
3259b2b
Try using async client
mattnappo Jul 15, 2024
3c3209e
Use servicer
mattnappo Jul 15, 2024
a9070ff
Use servicer
mattnappo Jul 15, 2024
153eb0a
Lint
mattnappo Jul 15, 2024
5bb4044
Add sleep in test
mattnappo Jul 15, 2024
91b5665
Reduce sleep time
mattnappo Jul 15, 2024
7e7d692
Merge branch 'main' into matt/gvisor-flake-fix
thundergolfer Jul 16, 2024
74ae84a
Address review
mattnappo Jul 16, 2024
7bbae23
Merge branch 'matt/gvisor-flake-fix' of github.com:modal-labs/modal-c…
mattnappo Jul 16, 2024
c0f5fcb
Run entire restore phase within the lock
mattnappo Jul 16, 2024
87a9036
Merge branch 'main' into matt/gvisor-flake-fix
thundergolfer Jul 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion modal/_container_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):

_client: _Client = synchronizer._translate_in(client) # TODO(erikbern): ugly

with container_io_manager.heartbeats(), UserCodeEventLoop() as event_loop:
with container_io_manager.heartbeats(function_def.is_checkpointing_function), UserCodeEventLoop() as event_loop:
mattnappo marked this conversation as resolved.
Show resolved Hide resolved
# If this is a serialized function, fetch the definition from the server
if function_def.definition_type == api_pb2.Function.DEFINITION_TYPE_SERIALIZED:
ser_cls, ser_fun = container_io_manager.get_serialized_function()
Expand Down Expand Up @@ -779,6 +779,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
if function_def.is_checkpointing_function:
container_io_manager.memory_snapshot()


# Install hooks for interactive functions.
if function_def.pty_info.pty_type != api_pb2.PTYInfo.PTY_TYPE_UNSPECIFIED:

Expand Down
39 changes: 24 additions & 15 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ class _ContainerIOManager:
_input_concurrency: Optional[int]
_semaphore: Optional[asyncio.Semaphore]
_environment_name: str
_waiting_for_memory_snapshot: bool
_heartbeat_loop: Optional[asyncio.Task]
_pause_heartbeats: Optional[asyncio.Condition]
mattnappo marked this conversation as resolved.
Show resolved Hide resolved
_waiting_for_memory_snapshot: bool

_is_interactivity_enabled: bool
_fetching_inputs: bool
Expand All @@ -101,8 +102,9 @@ def _init(self, container_args: api_pb2.ContainerArguments, client: _Client):

self._semaphore = None
self._environment_name = container_args.environment_name
self._waiting_for_memory_snapshot = False
self._heartbeat_loop = None
self._pause_heartbeats = asyncio.Condition()
mattnappo marked this conversation as resolved.
Show resolved Hide resolved
self._waiting_for_memory_snapshot = False

self._is_interactivity_enabled = False
self._fetching_inputs = True
Expand Down Expand Up @@ -146,22 +148,20 @@ async def _heartbeat_handle_cancellations(self) -> bool:
# Return True if a cancellation event was received, in that case
# we shouldn't wait too long for another heartbeat

# Don't send heartbeats for tasks waiting to be checkpointed.
# Calling gRPC methods open new connections which block the
# checkpointing process.
if self._waiting_for_memory_snapshot:
return False

request = api_pb2.ContainerHeartbeatRequest(supports_graceful_input_cancellation=True)
if self.current_input_id is not None:
request.current_input_id = self.current_input_id
if self.current_input_started_at is not None:
request.current_input_started_at = self.current_input_started_at

# TODO(erikbern): capture exceptions?
response = await retry_transient_errors(
self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
)
async with self._pause_heartbeats:
while self._waiting_for_memory_snapshot:
thundergolfer marked this conversation as resolved.
Show resolved Hide resolved
await self._pause_heartbeats.wait()

# TODO(erikbern): capture exceptions?
response = await retry_transient_errors(
self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
)

if response.HasField("cancel_input_event"):
# Pause processing of the current input by signaling self a SIGUSR1.
Expand Down Expand Up @@ -196,10 +196,11 @@ async def _heartbeat_handle_cancellations(self) -> bool:
return False

@asynccontextmanager
async def heartbeats(self) -> AsyncGenerator[None, None]:
async def heartbeats(self, disable_init: bool) -> AsyncGenerator[None, None]:
mattnappo marked this conversation as resolved.
Show resolved Hide resolved
async with TaskContext() as tc:
self._heartbeat_loop = t = tc.create_task(self._run_heartbeat_loop())
t.set_name("heartbeat loop")
self._waiting_for_memory_snapshot = disable_init
try:
yield
finally:
Expand Down Expand Up @@ -574,6 +575,11 @@ async def memory_restore(self) -> None:
await asyncio.sleep(0.01)
continue

# Turn heartbeats back on
mattnappo marked this conversation as resolved.
Show resolved Hide resolved
async with self._pause_heartbeats:
self._waiting_for_memory_snapshot = False
self._pause_heartbeats.notify_all()

logger.debug("Container: restored")

# Look for state file and create new client with updated credentials.
Expand Down Expand Up @@ -617,18 +623,21 @@ async def memory_restore(self) -> None:
)

self._client = await _Client.from_env()
self._waiting_for_memory_snapshot = False

async def memory_snapshot(self) -> None:
"""Message server indicating that function is ready to be checkpointed."""
if self.checkpoint_id:
logger.debug(f"Checkpoint ID: {self.checkpoint_id} (Memory Snapshot ID)")

# Pause heartbeats since they keep the client connection open which causes the snapshotter to crash
async with self._pause_heartbeats:
self._waiting_for_memory_snapshot = True
self._pause_heartbeats.notify_all()

await self._client.stub.ContainerCheckpoint(
api_pb2.ContainerCheckpointRequest(checkpoint_id=self.checkpoint_id)
)

self._waiting_for_memory_snapshot = True
await self._client._close(forget_credentials=True)

logger.debug("Memory snapshot request sent. Connection closed.")
Expand Down
34 changes: 33 additions & 1 deletion test/container_app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from typing import Dict
from unittest import mock
import asyncio

from google.protobuf.empty_pb2 import Empty
from google.protobuf.message import Message
Expand Down Expand Up @@ -58,12 +59,43 @@ async def test_container_snapshot_restore(container_client, tmpdir, servicer):
encoding="utf-8",
)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr},
):
io_manager.memory_snapshot()
# In-memory Client instance should have update credentials, not old credentials
assert old_client.credentials == ("ta-i-am-restored", "ts-i-am-restored")

@pytest.mark.asyncio
async def test_container_snapshot_restore_heartbeats(container_client, tmpdir, servicer):
io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = tmpdir.join("fake-restore-state.json")
restore_path.write_text(
json.dumps(
{
"task_id": "ta-i-am-restored",
"task_secret": "ts-i-am-restored",
"function_id": "fu-i-am-restored",
}
),
encoding="utf-8",
)

asyncio.set_event_loop(asyncio.new_event_loop())
mattnappo marked this conversation as resolved.
Show resolved Hide resolved

# Ensure that heartbeats do not run before the snapshot
# Ensure that heartbeats do run after the snapshot
with io_manager.heartbeats(True):
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr},
):
with mock.patch("modal.runner.HEARTBEAT_INTERVAL", 1):
with mock.patch.object(container_client.stub, 'ContainerHeartbeat') as mock_heartbeat:

await asyncio.sleep(2)
mock_heartbeat.assert_not_called()
io_manager.memory_snapshot()
mock_heartbeat.assert_called_once()
mattnappo marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.asyncio
async def test_container_debug_snapshot(container_client, tmpdir, servicer):
Expand Down
3 changes: 3 additions & 0 deletions test/container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,9 @@ def test_container_heartbeats(servicer):
_run_container(servicer, "test.supports.functions", "square")
assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in servicer.requests)

_run_container(servicer, "test.supports.functions", "snapshotting_square")
assert any(isinstance(request, api_pb2.ContainerHeartbeatRequest) for request in servicer.requests)


@skip_github_non_linux
def test_cli(servicer):
Expand Down
1 change: 1 addition & 0 deletions test/live_reload_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,4 @@ async def test_heartbeats(app_ref, server_url_env, servicer):
# Typically [0s, 1s, 2s, 3s], but asyncio.sleep may lag.
actual_heartbeats = servicer.app_heartbeats[apps[0]]
assert abs(actual_heartbeats - (total_secs + 1)) <= 1

5 changes: 5 additions & 0 deletions test/supports/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,11 @@ def f(self, x):
return "".join(self._vals) + x


@app.function(enable_memory_snapshot=True)
def snapshotting_square(x):
return x * x


@app.cls()
class EventLoopCls:
@enter()
Expand Down
Loading