diff --git a/modal/_output.py b/modal/_output.py index 31a3d8adfa..a0efae5e94 100644 --- a/modal/_output.py +++ b/modal/_output.py @@ -167,6 +167,7 @@ class OutputManager: _app_page_url: Optional[str] _show_image_logs: bool _status_spinner_live: Optional[Live] + _tree: Optional[Tree] def __init__( self, @@ -187,6 +188,7 @@ def __init__( self._app_page_url = None self._show_image_logs = False self._status_spinner_live = None + self._tree = None @classmethod def disable(cls): @@ -385,6 +387,21 @@ def show_status_spinner(self): with self._status_spinner_live: yield + @contextlib.contextmanager + def make_tree(self): + self._tree = Tree(step_progress("Creating objects..."), guide_style="gray50") + + try: + with self.make_live(self._tree): + yield + self._tree.label = step_completed("Created objects.") + self.print(self._tree) + finally: + self._tree = None + + def add_status_row(self) -> "StatusRow": + return StatusRow(self._tree) + class ProgressHandler: live: Live @@ -677,13 +694,13 @@ class FunctionCreationStatus: tag: str response: Optional[api_pb2.FunctionCreateResponse] = None - def __init__(self, resolver, tag): - self.resolver = resolver + def __init__(self, tag): self.tag = tag def __enter__(self): - self.status_row = self.resolver.add_status_row() - self.status_row.message(f"Creating function {self.tag}...") + if output_mgr := OutputManager.get(): + self.status_row = output_mgr.add_status_row() + self.status_row.message(f"Creating function {self.tag}...") return self def set_response(self, resp: api_pb2.FunctionCreateResponse): @@ -694,7 +711,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): raise exc_val if not self.response: - self.status_row.finish(f"Unknown error when creating function {self.tag}") + if self.status_row: + self.status_row.finish(f"Unknown error when creating function {self.tag}") elif self.response.function.web_url: url_info = self.response.function.web_url_info @@ -713,11 +731,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): # Print custom domain in terminal for custom_domain in self.response.function.custom_domain_info: - custom_domain_status_row = self.resolver.add_status_row() - custom_domain_status_row.finish( - f"Custom domain for {self.tag} => [magenta underline]" - f"{custom_domain.url}[/magenta underline]{suffix}" - ) + if output_mgr := OutputManager.get(): + custom_domain_status_row = output_mgr.add_status_row() + custom_domain_status_row.finish( + f"Custom domain for {self.tag} => [magenta underline]" + f"{custom_domain.url}[/magenta underline]{suffix}" + ) else: self.status_row.finish(f"Created function {self.tag}.") diff --git a/modal/_resolver.py b/modal/_resolver.py index 31d65775fd..ae3793918e 100644 --- a/modal/_resolver.py +++ b/modal/_resolver.py @@ -1,6 +1,5 @@ # Copyright Modal Labs 2023 import asyncio -import contextlib from asyncio import Future from typing import TYPE_CHECKING, Dict, Hashable, List, Optional @@ -12,7 +11,6 @@ from .exception import NotFoundError if TYPE_CHECKING: - from modal._output import StatusRow from modal.object import _Object @@ -30,12 +28,7 @@ def __init__( environment_name: Optional[str] = None, app_id: Optional[str] = None, ): - from rich.tree import Tree - - from ._output import step_progress - self._local_uuid_to_future = {} - self._tree = Tree(step_progress("Creating objects..."), guide_style="gray50") self._client = client self._app_id = app_id self._environment_name = environment_name @@ -130,21 +123,3 @@ def objects(self) -> List["_Object"]: obj = fut.result() unique_objects.setdefault(obj.object_id, obj) return list(unique_objects.values()) - - @contextlib.contextmanager - def display(self): - # TODO(erikbern): get rid of this wrapper - from ._output import OutputManager, step_completed - - if output_mgr := OutputManager.get(): - with output_mgr.make_live(self._tree): - yield - self._tree.label = step_completed("Created objects.") - output_mgr.print(self._tree) - else: - yield - - def add_status_row(self) -> "StatusRow": - from ._output import StatusRow - - return StatusRow(self._tree) diff --git a/modal/functions.py b/modal/functions.py index a01a44ada1..0978e3fef8 100644 --- a/modal/functions.py +++ b/modal/functions.py @@ -353,7 +353,7 @@ async def _load(method_bound_function: "_Function", resolver: Resolver, existing existing_function_id=existing_object_id or method_bound_function.object_id or "", ) assert resolver.client.stub is not None # client should be connected when load is called - with FunctionCreationStatus(resolver, full_name) as function_creation_status: + with FunctionCreationStatus(full_name) as function_creation_status: response = await resolver.client.stub.FunctionCreate(request) method_bound_function._hydrate( response.function_id, @@ -715,7 +715,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]): assert resolver.client and resolver.client.stub - with FunctionCreationStatus(resolver, tag) as function_creation_status: + with FunctionCreationStatus(tag) as function_creation_status: if is_generator: function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR else: diff --git a/modal/mount.py b/modal/mount.py index 4344e6b5a8..535d7c8037 100644 --- a/modal/mount.py +++ b/modal/mount.py @@ -19,6 +19,7 @@ from modal_proto import api_pb2 from modal_version import __version__ +from ._output import OutputManager from ._resolver import Resolver from ._utils.async_utils import synchronize_api from ._utils.blob_utils import FileUploadSpec, blob_upload_file, get_file_upload_spec_from_path @@ -429,12 +430,16 @@ async def _load_mount( accounted_hashes: set[str] = set() message_label = _Mount._description(self._entries) blob_upload_concurrency = asyncio.Semaphore(16) # Limit uploads of large files. - status_row = resolver.add_status_row() + if output_mgr := OutputManager.get(): + status_row = output_mgr.add_status_row() + else: + status_row = None async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile: nonlocal n_seen, n_finished, total_uploads, total_bytes n_seen += 1 - status_row.message(f"Creating mount {message_label}: Uploaded {n_finished}/{n_seen} files") + if status_row: + status_row.message(f"Creating mount {message_label}: Uploaded {n_finished}/{n_seen} files") remote_filename = file_spec.mount_filename mount_file = api_pb2.MountFile( @@ -492,7 +497,8 @@ async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile: logger.warning(f"Mount of '{message_label}' is empty.") # Build the mount. - status_row.message(f"Creating mount {message_label}: Finalizing index of {len(files)} files") + if status_row: + status_row.message(f"Creating mount {message_label}: Finalizing index of {len(files)} files") if self._deployment_name: req = api_pb2.MountGetOrCreateRequest( deployment_name=self._deployment_name, @@ -515,7 +521,8 @@ async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile: ) resp = await retry_transient_errors(resolver.client.stub.MountGetOrCreate, req, base_delay=1) - status_row.finish(f"Created mount {message_label}") + if status_row: + status_row.finish(f"Created mount {message_label}") logger.debug(f"Uploaded {total_uploads} new files and {total_bytes} bytes in {time.monotonic() - t0}s") self._hydrate(resp.mount_id, resolver.client, resp.handle_metadata) diff --git a/modal/runner.py b/modal/runner.py index 116fa063d0..935864f43b 100644 --- a/modal/runner.py +++ b/modal/runner.py @@ -2,6 +2,7 @@ import asyncio import dataclasses import os +from contextlib import contextmanager from multiprocessing.synchronize import Event from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Dict, List, Optional, TypeVar @@ -100,6 +101,15 @@ async def _init_local_app_from_name( ) +@contextmanager +def display(): + if output_mgr := OutputManager.get(): + with output_mgr.make_tree(): + yield + else: + yield + + async def _create_all_objects( client: _Client, running_app: RunningApp, @@ -116,7 +126,7 @@ async def _create_all_objects( environment_name=environment_name, app_id=running_app.app_id, ) - with resolver.display(): + with display(): # Get current objects, and reset all objects tag_to_object_id = running_app.tag_to_object_id running_app.tag_to_object_id = {}