diff --git a/modal/_output.py b/modal/_output.py index e16b95c37..6c771fb99 100644 --- a/modal/_output.py +++ b/modal/_output.py @@ -30,7 +30,6 @@ ) from rich.spinner import Spinner from rich.text import Text -from rich.tree import Tree from modal_proto import api_pb2 @@ -75,24 +74,6 @@ def substep_completed(message: str) -> RenderableType: return f"🔨 {message}" -class StatusRow: - def __init__(self, progress: "Optional[Tree]"): - self._spinner = None - self._step_node = None - if progress is not None: - self._spinner = step_progress() - self._step_node = progress.add(self._spinner) - - def message(self, message): - if self._spinner is not None: - self._spinner.update(text=message) - - def finish(self, message): - if self._step_node is not None: - self._spinner.update(text=message) - self._step_node.label = substep_completed(message) - - def download_progress_bar() -> Progress: """ Returns a progress bar suitable for showing file download progress. @@ -167,7 +148,6 @@ class OutputManager: _app_page_url: Optional[str] _show_image_logs: bool _status_spinner_live: Optional[Live] - _tree: Optional[Tree] def __init__( self, @@ -188,7 +168,6 @@ def __init__( self._app_page_url = None self._show_image_logs = False self._status_spinner_live = None - self._tree = None @classmethod def disable(cls): @@ -387,30 +366,6 @@ def show_status_spinner(self): with self._status_spinner_live: yield - @classmethod - @contextlib.contextmanager - def make_tree(cls): - if output_mgr := OutputManager.get(): - tree = output_mgr._tree = Tree(step_progress("Creating objects..."), guide_style="gray50") - with output_mgr.make_live(tree): - try: - yield - finally: - output_mgr._tree = None - tree.label = step_completed("Created objects.") - output_mgr.print(tree) - else: - yield - - @classmethod - def add_status_row(cls) -> "StatusRow": - # Return a status row to be used for object creation. - # If output isn't enabled, just create a hidden tree - if cls._instance and cls._instance._tree: - return StatusRow(cls._instance._tree) - else: - return StatusRow(Tree("")) - class ProgressHandler: live: Live @@ -703,11 +658,12 @@ class FunctionCreationStatus: tag: str response: Optional[api_pb2.FunctionCreateResponse] = None - def __init__(self, tag): + def __init__(self, resolver, tag): + self.resolver = resolver self.tag = tag def __enter__(self): - self.status_row = OutputManager.add_status_row() + self.status_row = self.resolver.add_status_row() self.status_row.message(f"Creating function {self.tag}...") return self @@ -738,7 +694,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Print custom domain in terminal for custom_domain in self.response.function.custom_domain_info: - custom_domain_status_row = OutputManager.add_status_row() + custom_domain_status_row = self.resolver.add_status_row() custom_domain_status_row.finish( f"Custom domain for {self.tag} => [magenta underline]" f"{custom_domain.url}[/magenta underline]{suffix}" diff --git a/modal/_resolver.py b/modal/_resolver.py index ae3793918..8fd69b3f8 100644 --- a/modal/_resolver.py +++ b/modal/_resolver.py @@ -1,5 +1,6 @@ # Copyright Modal Labs 2023 import asyncio +import contextlib from asyncio import Future from typing import TYPE_CHECKING, Dict, Hashable, List, Optional @@ -11,9 +12,35 @@ from .exception import NotFoundError if TYPE_CHECKING: + from rich.tree import Tree + from modal.object import _Object +class StatusRow: + def __init__(self, progress: "Optional[Tree]"): + from ._output import ( + step_progress, + ) + + self._spinner = None + self._step_node = None + if progress is not None: + self._spinner = step_progress() + self._step_node = progress.add(self._spinner) + + def message(self, message): + if self._spinner is not None: + self._spinner.update(text=message) + + def finish(self, message): + from ._output import substep_completed + + if self._step_node is not None: + self._spinner.update(text=message) + self._step_node.label = substep_completed(message) + + class Resolver: _local_uuid_to_future: Dict[str, Future] _environment_name: Optional[str] @@ -28,7 +55,12 @@ def __init__( environment_name: Optional[str] = None, app_id: Optional[str] = None, ): + from rich.tree import Tree + + from ._output import step_progress + self._local_uuid_to_future = {} + self._tree = Tree(step_progress("Creating objects..."), guide_style="gray50") self._client = client self._app_id = app_id self._environment_name = environment_name @@ -123,3 +155,19 @@ def objects(self) -> List["_Object"]: obj = fut.result() unique_objects.setdefault(obj.object_id, obj) return list(unique_objects.values()) + + @contextlib.contextmanager + def display(self): + # TODO(erikbern): get rid of this wrapper + from ._output import OutputManager, step_completed + + if output_mgr := OutputManager.get(): + with output_mgr.make_live(self._tree): + yield + self._tree.label = step_completed("Created objects.") + output_mgr.print(self._tree) + else: + yield + + def add_status_row(self) -> StatusRow: + return StatusRow(self._tree) diff --git a/modal/functions.py b/modal/functions.py index 0978e3fef..a01a44ada 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -353,7 +353,7 @@ async def _load(method_bound_function: "_Function", resolver: Resolver, existing existing_function_id=existing_object_id or method_bound_function.object_id or "", ) assert resolver.client.stub is not None # client should be connected when load is called - with FunctionCreationStatus(full_name) as function_creation_status: + with FunctionCreationStatus(resolver, full_name) as function_creation_status: response = await resolver.client.stub.FunctionCreate(request) method_bound_function._hydrate( response.function_id, @@ -715,7 +715,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub - with FunctionCreationStatus(tag) as function_creation_status: + with FunctionCreationStatus(resolver, tag) as function_creation_status: if is_generator: function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR else: diff --git a/modal/mount.py b/modal/mount.py index 779c743e4..4344e6b5a 100644 --- a/modal/mount.py +++ b/modal/mount.py @@ -19,7 +19,6 @@ from modal_proto import api_pb2 from modal_version import __version__ -from ._output import OutputManager from ._resolver import Resolver from ._utils.async_utils import synchronize_api from ._utils.blob_utils import FileUploadSpec, blob_upload_file, get_file_upload_spec_from_path @@ -430,7 +429,7 @@ async def _load_mount( accounted_hashes: set[str] = set() message_label = _Mount._description(self._entries) blob_upload_concurrency = asyncio.Semaphore(16) # Limit uploads of large files. - status_row = OutputManager.add_status_row() + status_row = resolver.add_status_row() async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile: nonlocal n_seen, n_finished, total_uploads, total_bytes diff --git a/modal/runner.py b/modal/runner.py index 905a9db4b..116fa063d 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -116,7 +116,7 @@ async def _create_all_objects( environment_name=environment_name, app_id=running_app.app_id, ) - with OutputManager.make_tree(): + with resolver.display(): # Get current objects, and reset all objects tag_to_object_id = running_app.tag_to_object_id running_app.tag_to_object_id = {} diff --git a/test/conftest.py b/test/conftest.py index c5cf8cf7e..4bc10e3e0 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -30,7 +30,6 @@ import modal._serialization from modal import __version__, config from modal._container_io_manager import _ContainerIOManager -from modal._output import OutputManager from modal._serialization import serialize_data_format from modal._utils.async_utils import asyncify, synchronize_api from modal._utils.grpc_testing import patch_mock_servicer @@ -1552,11 +1551,6 @@ async def reset_default_client(): Client.set_env_client(None) -@pytest_asyncio.fixture(scope="function", autouse=True) -async def reset_output_manager(): - OutputManager._instance = None - - @pytest.fixture(name="mock_dir", scope="session") def mock_dir_factory(): """Sets up a temp dir with content as specified in a nested dict