diff --git a/modal/_container_io_manager.py b/modal/_container_io_manager.py
index eaa16416e..641aac6c6 100644
--- a/modal/_container_io_manager.py
+++ b/modal/_container_io_manager.py
@@ -4,6 +4,7 @@
 import math
 import os
 import signal
+import sys
 import time
 import traceback
 from dataclasses import dataclass
@@ -601,6 +602,20 @@ async def memory_restore(self) -> None:
         self.current_input_id = None
         self.current_input_started_at = None

+        # Patch torch so it doesn't report CUDA as unavailable because of
+        # queries that were cached during the snapshot process. ref: MOD-3257
+        #
+        # perf: check sys.modules keys before importing to avoid a slow PYTHONPATH scan.
+        if "torch" in sys.modules:
+            try:
+                sys.modules["torch"].cuda.device_count = sys.modules["torch"].cuda._device_count_nvml
+            # Wide-open except to catch anything; we don't want to crash here.
+            except Exception as exc:
+                logger.warning(
+                    f"failed to patch 'torch.cuda.device_count' during snapshot restore: {exc}. "
+                    "CUDA device availability may be inaccurate."
+                )
+
         self._client = await _Client.from_env()

         self._waiting_for_memory_snapshot = False
diff --git a/modal_proto/api.proto b/modal_proto/api.proto
index 2fc0a6020..2e95a92f1 100644
--- a/modal_proto/api.proto
+++ b/modal_proto/api.proto
@@ -87,9 +87,7 @@ enum CloudProvider {
   CLOUD_PROVIDER_GCP = 2;
   CLOUD_PROVIDER_AUTO = 3;
   CLOUD_PROVIDER_OCI = 4;
-  CLOUD_PROVIDER_LAMBDA_LABS = 5;
-  CLOUD_PROVIDER_FLUIDSTACK = 6; // experimental
-  CLOUD_PROVIDER_LATITUDE = 7; // experimental
+  reserved 5, 6, 7; // now-unused internal experimental values
 }

 enum DNSRecordType {
diff --git a/modal_version/_version_generated.py b/modal_version/_version_generated.py
index effb52baa..f87678aae 100644
--- a/modal_version/_version_generated.py
+++ b/modal_version/_version_generated.py
@@ -1,4 +1,4 @@
 # Copyright Modal Labs 2024

 # Note: Reset this value to -1 whenever you make a minor `0.X` release of the client.
-build_number = 51  # git: 09034a1
+build_number = 52  # git: 96ce76a
diff --git a/test/container_app_test.py b/test/container_app_test.py
index f9d0b9775..90f753e02 100644
--- a/test/container_app_test.py
+++ b/test/container_app_test.py
@@ -86,6 +86,95 @@ async def test_container_debug_snapshot(container_client, tmpdir, servicer):
         test_breakpoint.assert_called_once()


+@pytest.fixture(scope="function")
+def fake_torch_module():
+    module_path = os.path.join(os.getcwd(), "torch.py")
+    with open(module_path, "w") as f:
+        f.write(
+            """
+import dataclasses
+@dataclasses.dataclass
+class CUDA:
+    device_count = lambda self: 0
+    _device_count_nvml = lambda self: 2
+
+cuda = CUDA()
+"""
+        )
+
+    yield module_path
+    # Teardown: remove the torch.py file
+    os.remove(module_path)
+
+
+@pytest.fixture(scope="function")
+def weird_torch_module():
+    module_path = os.path.join(os.getcwd(), "torch.py")
+    with open(module_path, "w") as f:
+        f.write("IM_WEIRD = 42\n")
+
+    yield module_path
+
+    os.remove(module_path)  # Teardown: remove the torch.py file
+
+
+@pytest.mark.asyncio
+async def test_container_snapshot_patching(fake_torch_module, container_client, tmpdir, servicer):
+    io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
+    restore_path = tmpdir.join("fake-restore-state.json")
+
+    # Bring the fake torch module into scope so it is registered in sys.modules
+    import torch
+
+    assert torch.cuda.device_count() == 0
+
+    # Write out a restore file so that snapshot+restore will complete
+    restore_path.write_text(
+        json.dumps(
+            {
+                "task_id": "ta-i-am-restored",
+                "task_secret": "ts-i-am-restored",
+                "function_id": "fu-i-am-restored",
+            }
+        ),
+        encoding="utf-8",
+    )
+    with mock.patch.dict(
+        os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
+    ):
+        io_manager.memory_snapshot()
+    import torch
+
+    assert torch.cuda.device_count() == 2
+
+
+@pytest.mark.asyncio
+async def test_container_snapshot_patching_err(weird_torch_module, container_client, tmpdir, servicer):
+    io_manager = ContainerIOManager(api_pb2.ContainerArguments(), container_client)
+    restore_path = tmpdir.join("fake-restore-state.json")
+
+    # Bring the weird torch module into scope so it is registered in sys.modules
+    import torch as trch
+
+    assert trch.IM_WEIRD == 42
+
+    # Write out a restore file so that snapshot+restore will complete
+    restore_path.write_text(
+        json.dumps(
+            {
+                "task_id": "ta-i-am-restored",
+                "task_secret": "ts-i-am-restored",
+                "function_id": "fu-i-am-restored",
+            }
+        ),
+        encoding="utf-8",
+    )
+    with mock.patch.dict(
+        os.environ, {"MODAL_RESTORE_STATE_PATH": str(restore_path), "MODAL_SERVER_URL": servicer.container_addr}
+    ):
+        io_manager.memory_snapshot()  # should not crash
+
+
 def test_interact(container_client, servicer):
     # Initialize container singleton
     ContainerIOManager(api_pb2.ContainerArguments(), container_client)
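
For readers skimming the diff, here is a minimal standalone sketch of the restore-time patch, assuming only what the diff itself relies on (torch's internal '_device_count_nvml' helper); the wrapper name '_refresh_cuda_device_count' is illustrative and not part of the change.

import logging
import sys

logger = logging.getLogger(__name__)


def _refresh_cuda_device_count() -> None:
    # Only act when torch was imported before the snapshot; checking
    # sys.modules avoids triggering a slow PYTHONPATH scan on restore.
    if "torch" not in sys.modules:
        return
    torch = sys.modules["torch"]
    try:
        # _device_count_nvml asks NVML for the device count, bypassing the
        # value torch.cuda cached while the snapshot was being taken.
        torch.cuda.device_count = torch.cuda._device_count_nvml
    except Exception as exc:  # never crash the restore path
        logger.warning(
            "failed to patch 'torch.cuda.device_count' during snapshot restore: %s. "
            "CUDA device availability may be inaccurate.",
            exc,
        )

Calling this right after a memory restore, before any user code asks about GPUs, makes subsequent torch.cuda.device_count() calls reflect the devices attached to the restored container rather than the pre-snapshot value; the broad except mirrors the diff's intent that an imperfect patch must never crash a restoring container.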