Skip to content

Commit

Permalink
Revert output mgr stuff (#2063)
Browse files Browse the repository at this point in the history
* Revert "Fix issue with global output manager tree (#2062)"

This reverts commit 2346aa9.

* Revert "Remove all output mgmt from the resolver (#2060)"

This reverts commit f1140b1.
  • Loading branch information
erikbern committed Jul 31, 2024
1 parent 0efbdef commit 5058f98
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 59 deletions.
52 changes: 4 additions & 48 deletions modal/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
)
from rich.spinner import Spinner
from rich.text import Text
from rich.tree import Tree

from modal_proto import api_pb2

Expand Down Expand Up @@ -75,24 +74,6 @@ def substep_completed(message: str) -> RenderableType:
return f"🔨 {message}"


class StatusRow:
def __init__(self, progress: "Optional[Tree]"):
self._spinner = None
self._step_node = None
if progress is not None:
self._spinner = step_progress()
self._step_node = progress.add(self._spinner)

def message(self, message):
if self._spinner is not None:
self._spinner.update(text=message)

def finish(self, message):
if self._step_node is not None:
self._spinner.update(text=message)
self._step_node.label = substep_completed(message)


def download_progress_bar() -> Progress:
"""
Returns a progress bar suitable for showing file download progress.
Expand Down Expand Up @@ -167,7 +148,6 @@ class OutputManager:
_app_page_url: Optional[str]
_show_image_logs: bool
_status_spinner_live: Optional[Live]
_tree: Optional[Tree]

def __init__(
self,
Expand All @@ -188,7 +168,6 @@ def __init__(
self._app_page_url = None
self._show_image_logs = False
self._status_spinner_live = None
self._tree = None

@classmethod
def disable(cls):
Expand Down Expand Up @@ -387,30 +366,6 @@ def show_status_spinner(self):
with self._status_spinner_live:
yield

@classmethod
@contextlib.contextmanager
def make_tree(cls):
if output_mgr := OutputManager.get():
tree = output_mgr._tree = Tree(step_progress("Creating objects..."), guide_style="gray50")
with output_mgr.make_live(tree):
try:
yield
finally:
output_mgr._tree = None
tree.label = step_completed("Created objects.")
output_mgr.print(tree)
else:
yield

@classmethod
def add_status_row(cls) -> "StatusRow":
# Return a status row to be used for object creation.
# If output isn't enabled, just create a hidden tree
if cls._instance and cls._instance._tree:
return StatusRow(cls._instance._tree)
else:
return StatusRow(Tree(""))


class ProgressHandler:
live: Live
Expand Down Expand Up @@ -703,11 +658,12 @@ class FunctionCreationStatus:
tag: str
response: Optional[api_pb2.FunctionCreateResponse] = None

def __init__(self, tag):
def __init__(self, resolver, tag):
self.resolver = resolver
self.tag = tag

def __enter__(self):
self.status_row = OutputManager.add_status_row()
self.status_row = self.resolver.add_status_row()
self.status_row.message(f"Creating function {self.tag}...")
return self

Expand Down Expand Up @@ -738,7 +694,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

# Print custom domain in terminal
for custom_domain in self.response.function.custom_domain_info:
custom_domain_status_row = OutputManager.add_status_row()
custom_domain_status_row = self.resolver.add_status_row()
custom_domain_status_row.finish(
f"Custom domain for {self.tag} => [magenta underline]"
f"{custom_domain.url}[/magenta underline]{suffix}"
Expand Down
48 changes: 48 additions & 0 deletions modal/_resolver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright Modal Labs 2023
import asyncio
import contextlib
from asyncio import Future
from typing import TYPE_CHECKING, Dict, Hashable, List, Optional

Expand All @@ -11,9 +12,35 @@
from .exception import NotFoundError

if TYPE_CHECKING:
from rich.tree import Tree

from modal.object import _Object


class StatusRow:
def __init__(self, progress: "Optional[Tree]"):
from ._output import (
step_progress,
)

self._spinner = None
self._step_node = None
if progress is not None:
self._spinner = step_progress()
self._step_node = progress.add(self._spinner)

def message(self, message):
if self._spinner is not None:
self._spinner.update(text=message)

def finish(self, message):
from ._output import substep_completed

if self._step_node is not None:
self._spinner.update(text=message)
self._step_node.label = substep_completed(message)


class Resolver:
_local_uuid_to_future: Dict[str, Future]
_environment_name: Optional[str]
Expand All @@ -28,7 +55,12 @@ def __init__(
environment_name: Optional[str] = None,
app_id: Optional[str] = None,
):
from rich.tree import Tree

from ._output import step_progress

self._local_uuid_to_future = {}
self._tree = Tree(step_progress("Creating objects..."), guide_style="gray50")
self._client = client
self._app_id = app_id
self._environment_name = environment_name
Expand Down Expand Up @@ -123,3 +155,19 @@ def objects(self) -> List["_Object"]:
obj = fut.result()
unique_objects.setdefault(obj.object_id, obj)
return list(unique_objects.values())

@contextlib.contextmanager
def display(self):
# TODO(erikbern): get rid of this wrapper
from ._output import OutputManager, step_completed

if output_mgr := OutputManager.get():
with output_mgr.make_live(self._tree):
yield
self._tree.label = step_completed("Created objects.")
output_mgr.print(self._tree)
else:
yield

def add_status_row(self) -> StatusRow:
return StatusRow(self._tree)
4 changes: 2 additions & 2 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ async def _load(method_bound_function: "_Function", resolver: Resolver, existing
existing_function_id=existing_object_id or method_bound_function.object_id or "",
)
assert resolver.client.stub is not None # client should be connected when load is called
with FunctionCreationStatus(full_name) as function_creation_status:
with FunctionCreationStatus(resolver, full_name) as function_creation_status:
response = await resolver.client.stub.FunctionCreate(request)
method_bound_function._hydrate(
response.function_id,
Expand Down Expand Up @@ -715,7 +715,7 @@ async def _preload(self: _Function, resolver: Resolver, existing_object_id: Opti

async def _load(self: _Function, resolver: Resolver, existing_object_id: Optional[str]):
assert resolver.client and resolver.client.stub
with FunctionCreationStatus(tag) as function_creation_status:
with FunctionCreationStatus(resolver, tag) as function_creation_status:
if is_generator:
function_type = api_pb2.Function.FUNCTION_TYPE_GENERATOR
else:
Expand Down
3 changes: 1 addition & 2 deletions modal/mount.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from modal_proto import api_pb2
from modal_version import __version__

from ._output import OutputManager
from ._resolver import Resolver
from ._utils.async_utils import synchronize_api
from ._utils.blob_utils import FileUploadSpec, blob_upload_file, get_file_upload_spec_from_path
Expand Down Expand Up @@ -430,7 +429,7 @@ async def _load_mount(
accounted_hashes: set[str] = set()
message_label = _Mount._description(self._entries)
blob_upload_concurrency = asyncio.Semaphore(16) # Limit uploads of large files.
status_row = OutputManager.add_status_row()
status_row = resolver.add_status_row()

async def _put_file(file_spec: FileUploadSpec) -> api_pb2.MountFile:
nonlocal n_seen, n_finished, total_uploads, total_bytes
Expand Down
2 changes: 1 addition & 1 deletion modal/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ async def _create_all_objects(
environment_name=environment_name,
app_id=running_app.app_id,
)
with OutputManager.make_tree():
with resolver.display():
# Get current objects, and reset all objects
tag_to_object_id = running_app.tag_to_object_id
running_app.tag_to_object_id = {}
Expand Down
6 changes: 0 additions & 6 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import modal._serialization
from modal import __version__, config
from modal._container_io_manager import _ContainerIOManager
from modal._output import OutputManager
from modal._serialization import serialize_data_format
from modal._utils.async_utils import asyncify, synchronize_api
from modal._utils.grpc_testing import patch_mock_servicer
Expand Down Expand Up @@ -1552,11 +1551,6 @@ async def reset_default_client():
Client.set_env_client(None)


@pytest_asyncio.fixture(scope="function", autouse=True)
async def reset_output_manager():
OutputManager._instance = None


@pytest.fixture(name="mock_dir", scope="session")
def mock_dir_factory():
"""Sets up a temp dir with content as specified in a nested dict
Expand Down

0 comments on commit 5058f98

Please sign in to comment.