Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove all output mgmt from the resolver #2060

Merged
merged 5 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions modal/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from rich.spinner import Spinner
from rich.text import Text
from rich.tree import Tree

from modal_proto import api_pb2

Expand Down Expand Up @@ -74,6 +75,24 @@ def substep_completed(message: str) -> RenderableType:
return f"🔨 {message}"


class StatusRow:
def __init__(self, progress: "Optional[Tree]"):
self._spinner = None
self._step_node = None
if progress is not None:
self._spinner = step_progress()
self._step_node = progress.add(self._spinner)

def message(self, message):
if self._spinner is not None:
self._spinner.update(text=message)

def finish(self, message):
if self._step_node is not None:
self._spinner.update(text=message)
self._step_node.label = substep_completed(message)


def download_progress_bar() -> Progress:
"""
Returns a progress bar suitable for showing file download progress.
Expand Down Expand Up @@ -135,6 +154,7 @@ def finalize(self):

class OutputManager:
_instance: ClassVar[Optional["OutputManager"]] = None
_tree: ClassVar[Optional[Tree]] = None

_console: Console
_task_states: Dict[str, int]
Expand Down Expand Up @@ -366,6 +386,27 @@ def show_status_spinner(self):
with self._status_spinner_live:
yield

@classmethod
@contextlib.contextmanager
def make_tree(cls):
# Note: If the output isn't enabled, don't actually show the tree.
cls._tree = Tree(step_progress("Creating objects..."), guide_style="gray50")

if output_mgr := OutputManager.get():
with output_mgr.make_live(cls._tree):
yield
cls._tree.label = step_completed("Created objects.")
output_mgr.print(output_mgr._tree)
else:
yield

@classmethod
def add_status_row(cls) -> "StatusRow":
# Return a status row to be used for object creation.
# If output isn't enabled, the status row might be invisible.
assert cls._tree, "Output manager has no tree yet"
return StatusRow(cls._tree)


class ProgressHandler:
live: Live
Expand Down Expand Up @@ -658,12 +699,11 @@ class FunctionCreationStatus:
tag: str
response: Optional[api_pb2.FunctionCreateResponse] = None

def __init__(self, resolver, tag):
self.resolver = resolver
def __init__(self, tag):
self.tag = tag

def __enter__(self):
self.status_row = self.resolver.add_status_row()
self.status_row = OutputManager.add_status_row()
self.status_row.message(f"Creating function {self.tag}...")
return self

Expand Down Expand Up @@ -694,7 +734,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

# Print custom domain in terminal
for custom_domain in self.response.function.custom_domain_info:
custom_domain_status_row = self.resolver.add_status_row()
custom_domain_status_row = OutputManager.add_status_row()
custom_domain_status_row.finish(
f"Custom domain for {self.tag} => [magenta underline]"
f"{custom_domain.url}[/magenta underline]{suffix}"
Expand Down
48 changes: 0 additions & 48 deletions modal/_resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright Modal Labs 2023
import asyncio
import contextlib
from asyncio import Future
from typing import TYPE_CHECKING, Dict, Hashable, List, Optional

Expand All @@ -12,35 +11,9 @@
from .exception import NotFoundError

if TYPE_CHECKING:
from rich.tree import Tree

from modal.object import _Object


class StatusRow:
def __init__(self, progress: "Optional[Tree]"):
from ._output import (
step_progress,
)

self._spinner = None
self._step_node = None
if progress is not None:
self._spinner = step_progress()
self._step_node = progress.add(self._spinner)

def message(self, message):
if self._spinner is not None:
self._spinner.update(text=message)

def finish(self, message):
from ._output import substep_completed

if self._step_node is not None:
self._spinner.update(text=message)
self._step_node.label = substep_completed(message)


class Resolver:
_local_uuid_to_future: Dict[str, Future]
_environment_name: Optional[str]
Expand All @@ -55,12 +28,7 @@ def __init__(
environment_name: Optional[str] = None,
app_id: Optional[str] = None,
):
from rich.tree import Tree

from ._output import step_progress

self._local_uuid_to_future = {}
self._tree = Tree(step_progress("Creating objects..."), guide_style="gray50")
self._client = client
self._app_id = app_id
self._environment_name = environment_name
Expand Down Expand Up @@ -155,19 +123,3 @@ def objects(self) -> List["_Object"]:
obj = fut.result()
unique_objects.setdefault(obj.object_id, obj)
return list(unique_objects.values())

@contextlib.contextmanager
def display(self):
# TODO(erikbern): get rid of this wrapper
from ._output import OutputManager, step_completed

if output_mgr := OutputManager.get():
with output_mgr.make_live(self._tree):
yield
self._tree.label = step_completed("Created objects.")
output_mgr.print(self._tree)
else:
yield

def add_status_row(self) -> StatusRow:
return StatusRow(self._tree)
4 changes: 2 additions & 2 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ async def _load(method_bound_function: "_Function", resolver: Resolver, existing
existing_function_id=existing_object_id or method_bound_function.object_id or "",
)
assert resolver.client.stub is not None # client should be connected when load is called
with FunctionCreationStatus(resolver, full_name) as function_creation_status:
with FunctionCreationStatus(full_name) as function_creation_status:
response = await resolver.client.stub.FunctionCreate(request)
method_bound_function._hydrate(
response.function_id,
Expand Down Expand Up @@ -715,7 +715,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti

async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
assert resolver.client and resolver.client.stub
with FunctionCreationStatus(resolver, tag) as function_creation_status:
with FunctionCreationStatus(tag) as function_creation_status:
if is_generator:
function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR
else:
Expand Down
3 changes: 2 additions & 1 deletion modal/mount.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from modal_proto import api_pb2
from modal_version import __version__

from ._output import OutputManager
from ._resolver import Resolver
from ._utils.async_utils import synchronize_api
from ._utils.blob_utils import FileUploadSpec, blob_upload_file, get_file_upload_spec_from_path
Expand Down Expand Up @@ -429,7 +430,7 @@ async def _load_mount(
accounted_hashes: set[str] = set()
message_label = _Mount._description(self._entries)
blob_upload_concurrency = asyncio.Semaphore(16) # Limit uploads of large files.
status_row = resolver.add_status_row()
status_row = OutputManager.add_status_row()

async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile:
nonlocal n_seen, n_finished, total_uploads, total_bytes
Expand Down
2 changes: 1 addition & 1 deletion modal/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ async def _create_all_objects(
environment_name=environment_name,
app_id=running_app.app_id,
)
with resolver.display():
with OutputManager.make_tree():
# Get current objects, and reset all objects
tag_to_object_id = running_app.tag_to_object_id
running_app.tag_to_object_id = {}
Expand Down
Loading