Skip to content

Commit

Permalink
Move the remaining resolver output mgmt to _output
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Jul 30, 2024
1 parent 393a4f7 commit d03897d
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 42 deletions.
39 changes: 29 additions & 10 deletions modal/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class OutputManager:
_app_page_url: Optional[str]
_show_image_logs: bool
_status_spinner_live: Optional[Live]
_tree: Optional[Tree]

def __init__(
self,
Expand All @@ -187,6 +188,7 @@ def __init__(
self._app_page_url = None
self._show_image_logs = False
self._status_spinner_live = None
self._tree = None

@classmethod
def disable(cls):
Expand Down Expand Up @@ -385,6 +387,21 @@ def show_status_spinner(self):
with self._status_spinner_live:
yield

@contextlib.contextmanager
def make_tree(self):
self._tree = Tree(step_progress("Creating objects..."), guide_style="gray50")

try:
with self.make_live(self._tree):
yield
self._tree.label = step_completed("Created objects.")
self.print(self._tree)
finally:
self._tree = None

def add_status_row(self) -> "StatusRow":
return StatusRow(self._tree)


class ProgressHandler:
live: Live
Expand Down Expand Up @@ -677,13 +694,13 @@ class FunctionCreationStatus:
tag: str
response: Optional[api_pb2.FunctionCreateResponse] = None

def __init__(self, resolver, tag):
self.resolver = resolver
def __init__(self, tag):
self.tag = tag

def __enter__(self):
self.status_row = self.resolver.add_status_row()
self.status_row.message(f"Creating function {self.tag}...")
if output_mgr := OutputManager.get():
self.status_row = output_mgr.add_status_row()
self.status_row.message(f"Creating function {self.tag}...")
return self

def set_response(self, resp: api_pb2.FunctionCreateResponse):
Expand All @@ -694,7 +711,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
raise exc_val

if not self.response:
self.status_row.finish(f"Unknown error when creating function {self.tag}")
if self.status_row:
self.status_row.finish(f"Unknown error when creating function {self.tag}")

elif self.response.function.web_url:
url_info = self.response.function.web_url_info
Expand All @@ -713,11 +731,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):

# Print custom domain in terminal
for custom_domain in self.response.function.custom_domain_info:
custom_domain_status_row = self.resolver.add_status_row()
custom_domain_status_row.finish(
f"Custom domain for {self.tag} => [magenta underline]"
f"{custom_domain.url}[/magenta underline]{suffix}"
)
if output_mgr := OutputManager.get():
custom_domain_status_row = output_mgr.add_status_row()
custom_domain_status_row.finish(
f"Custom domain for {self.tag} => [magenta underline]"
f"{custom_domain.url}[/magenta underline]{suffix}"
)
else:
self.status_row.finish(f"Created function {self.tag}.")

Expand Down
25 changes: 0 additions & 25 deletions modal/_resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright Modal Labs 2023
import asyncio
import contextlib
from asyncio import Future
from typing import TYPE_CHECKING, Dict, Hashable, List, Optional

Expand All @@ -12,7 +11,6 @@
from .exception import NotFoundError

if TYPE_CHECKING:
from modal._output import StatusRow
from modal.object import _Object


Expand All @@ -30,12 +28,7 @@ def __init__(
environment_name: Optional[str] = None,
app_id: Optional[str] = None,
):
from rich.tree import Tree

from ._output import step_progress

self._local_uuid_to_future = {}
self._tree = Tree(step_progress("Creating objects..."), guide_style="gray50")
self._client = client
self._app_id = app_id
self._environment_name = environment_name
Expand Down Expand Up @@ -130,21 +123,3 @@ def objects(self) -> List["_Object"]:
obj = fut.result()
unique_objects.setdefault(obj.object_id, obj)
return list(unique_objects.values())

@contextlib.contextmanager
def display(self):
# TODO(erikbern): get rid of this wrapper
from ._output import OutputManager, step_completed

if output_mgr := OutputManager.get():
with output_mgr.make_live(self._tree):
yield
self._tree.label = step_completed("Created objects.")
output_mgr.print(self._tree)
else:
yield

def add_status_row(self) -> "StatusRow":
from ._output import StatusRow

return StatusRow(self._tree)
4 changes: 2 additions & 2 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ async def _load(method_bound_function: "_Function", resolver: Resolver, existing
existing_function_id=existing_object_id or method_bound_function.object_id or "",
)
assert resolver.client.stub is not None # client should be connected when load is called
with FunctionCreationStatus(resolver, full_name) as function_creation_status:
with FunctionCreationStatus(full_name) as function_creation_status:
response = await resolver.client.stub.FunctionCreate(request)
method_bound_function._hydrate(
response.function_id,
Expand Down Expand Up @@ -715,7 +715,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti

async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
assert resolver.client and resolver.client.stub
with FunctionCreationStatus(resolver, tag) as function_creation_status:
with FunctionCreationStatus(tag) as function_creation_status:
if is_generator:
function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR
else:
Expand Down
15 changes: 11 additions & 4 deletions modal/mount.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from modal_proto import api_pb2
from modal_version import __version__

from ._output import OutputManager
from ._resolver import Resolver
from ._utils.async_utils import synchronize_api
from ._utils.blob_utils import FileUploadSpec, blob_upload_file, get_file_upload_spec_from_path
Expand Down Expand Up @@ -429,12 +430,16 @@ async def _load_mount(
accounted_hashes: set[str] = set()
message_label = _Mount._description(self._entries)
blob_upload_concurrency = asyncio.Semaphore(16) # Limit uploads of large files.
status_row = resolver.add_status_row()
if output_mgr := OutputManager.get():
status_row = output_mgr.add_status_row()
else:
status_row = None

async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile:
nonlocal n_seen, n_finished, total_uploads, total_bytes
n_seen += 1
status_row.message(f"Creating mount {message_label}: Uploaded {n_finished}/{n_seen} files")
if status_row:
status_row.message(f"Creating mount {message_label}: Uploaded {n_finished}/{n_seen} files")

remote_filename = file_spec.mount_filename
mount_file = api_pb2.MountFile(
Expand Down Expand Up @@ -492,7 +497,8 @@ async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile:
logger.warning(f"Mount of '{message_label}' is empty.")

# Build the mount.
status_row.message(f"Creating mount {message_label}: Finalizing index of {len(files)} files")
if status_row:
status_row.message(f"Creating mount {message_label}: Finalizing index of {len(files)} files")
if self._deployment_name:
req = api_pb2.MountGetOrCreateRequest(
deployment_name=self._deployment_name,
Expand All @@ -515,7 +521,8 @@ async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile:
)

resp = await retry_transient_errors(resolver.client.stub.MountGetOrCreate, req, base_delay=1)
status_row.finish(f"Created mount {message_label}")
if status_row:
status_row.finish(f"Created mount {message_label}")

logger.debug(f"Uploaded {total_uploads} new files and {total_bytes} bytes in {time.monotonic() - t0}s")
self._hydrate(resp.mount_id, resolver.client, resp.handle_metadata)
Expand Down
12 changes: 11 additions & 1 deletion modal/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import asyncio
import dataclasses
import os
from contextlib import contextmanager
from multiprocessing.synchronize import Event
from typing import TYPE_CHECKING, Any, AsyncGenerator, Coroutine, Dict, List, Optional, TypeVar

Expand Down Expand Up @@ -100,6 +101,15 @@ async def _init_local_app_from_name(
)


@contextmanager
def display():
if output_mgr := OutputManager.get():
with output_mgr.make_tree():
yield
else:
yield


async def _create_all_objects(
client: _Client,
running_app: RunningApp,
Expand All @@ -116,7 +126,7 @@ async def _create_all_objects(
environment_name=environment_name,
app_id=running_app.app_id,
)
with resolver.display():
with display():
# Get current objects, and reset all objects
tag_to_object_id = running_app.tag_to_object_id
running_app.tag_to_object_id = {}
Expand Down

0 comments on commit d03897d

Please sign in to comment.