Skip to content

Commit

Permalink
Merge branch 'main' into matt/gvisor-flake-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
thundergolfer committed Jul 12, 2024
2 parents f3c04c6 + bd5d738 commit 41adde8
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 4 deletions.
15 changes: 15 additions & 0 deletions modal/_container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import math
import os
import signal
import sys
import time
import traceback
from dataclasses import dataclass
Expand Down Expand Up @@ -601,6 +602,20 @@ async def memory_restore(self) -> None:
self.current_input_id = None
self.current_input_started_at = None

# Patch torch to ensure it doesn't return CUDA unavailibility due to
# cached queries that executed during snapshot process. ref: MOD-3257
#
# perf: scanning sys.modules keys before import to avoid slow PYTHONPATH scanning.
if "torch" in sys.modules:
try:
sys.modules["torch"].cuda.device_count = sys.modules["torch"].cuda._device_count_nvml
# Wide-open except to catch anything. We don't want to crash here.
except Exception as exc:
logger.warning(
f"failed to patch 'torch.cuda.device_count' during snapshot restore: {exc}. "
"CUDA device availability may be inaccurate."
)

self._client = await _Client.from_env()
self._waiting_for_memory_snapshot = False

Expand Down
4 changes: 1 addition & 3 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ enum CloudProvider {
CLOUD_PROVIDER_GCP = 2;
CLOUD_PROVIDER_AUTO = 3;
CLOUD_PROVIDER_OCI = 4;
CLOUD_PROVIDER_LAMBDA_LABS = 5;
CLOUD_PROVIDER_FLUIDSTACK = 6; // experimental
CLOUD_PROVIDER_LATITUDE = 7; // experimental
reserved 5, 6, 7; // now unused internal experimental values
}

enum DNSRecordType {
Expand Down
2 changes: 1 addition & 1 deletion modal_version/_version_generated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright Modal Labs 2024

# Note: Reset this value to -1 whenever you make a minor `0.X` release of the client.
build_number = 51 # git: 09034a1
build_number = 52 # git: 96ce76a
89 changes: 89 additions & 0 deletions test/container_app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,95 @@ async def test_container_debug_snapshot(container_client, tmpdir, servicer):
test_breakpoint.assert_called_once()


@pytest.fixture(scope="function")
def fake_torch_module():
module_path = os.path.join(os.getcwd(), "torch.py")
with open(module_path, "w") as f:
f.write(
"""
import dataclasses
@dataclasses.dataclass
class CUDA:
device_count = lambda self: 0
_device_count_nvml = lambda self: 2
cuda = CUDA()
"""
)

yield module_path
# Teardown: remove the torch.py file
os.remove(module_path)


@pytest.fixture(scope="function")
def weird_torch_module():
module_path = os.path.join(os.getcwd(), "torch.py")
with open(module_path, "w") as f:
f.write("IM_WEIRD = 42\n")

yield module_path

os.remove(module_path) # Teardown: remove the torch.py file


@pytest.mark.asyncio
async def test_container_snapshot_patching(fake_torch_module, container_client, tmpdir, servicer):
io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = tmpdir.join("fake-restore-state.json")

# bring fake torch into scope and call the utility fn
import torch

assert torch.cuda.device_count() == 0

# Write out a restore file so that snapshot+restore will complete
restore_path.write_text(
json.dumps(
{
"task_id": "ta-i-am-restored",
"task_secret": "ts-i-am-restored",
"function_id": "fu-i-am-restored",
}
),
encoding="utf-8",
)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
io_manager.memory_snapshot()
import torch

assert torch.cuda.device_count() == 2


@pytest.mark.asyncio
async def test_container_snapshot_patching_err(weird_torch_module, container_client, tmpdir, servicer):
io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
restore_path = tmpdir.join("fake-restore-state.json")

# bring weird torch into scope and call the utility fn
import torch as trch

assert trch.IM_WEIRD == 42

# Write out a restore file so that snapshot+restore will complete
restore_path.write_text(
json.dumps(
{
"task_id": "ta-i-am-restored",
"task_secret": "ts-i-am-restored",
"function_id": "fu-i-am-restored",
}
),
encoding="utf-8",
)
with mock.patch.dict(
os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
):
io_manager.memory_snapshot() # should not crash


def test_interact(container_client, servicer):
# Initialize container singleton
ContainerIOManager(api_pb2.ContainerArguments(), container_client)
Expand Down

0 comments on commit 41adde8

Please sign in to comment.