[MOD-3267] progress bar for volume/nfs put/get (#1969)
kramstrom committed Jul 16, 2024
1 parent 761bc65 commit 852e49e
Showing 7 changed files with 204 additions and 34 deletions.
112 changes: 112 additions & 0 deletions modal/_output.py
@@ -387,6 +387,118 @@ def hide_status_spinner(self):
self._status_spinner_live.stop()


class ProgressHandler:
live: Live
_type: str
_spinner: Spinner
_overall_progress: Progress
_download_progress: Progress
_overall_progress_task_id: TaskID
_total_tasks: int
_completed_tasks: int

def __init__(self, type: str, console: Console):
self._type = type

if self._type == "download":
title = "Downloading file(s) to local..."
elif self._type == "upload":
title = "Uploading file(s) to volume..."
else:
raise NotImplementedError(f"Progress handler of type: `{type}` not yet implemented")

self._spinner = step_progress(title)

self._overall_progress = Progress(
TextColumn(f"[bold white]{title}", justify="right"),
TimeElapsedColumn(),
BarColumn(bar_width=None),
TextColumn("[bold white]{task.description}"),
transient=True,
console=console,
)
self._download_progress = Progress(
TextColumn("[bold white]{task.fields[path]}", justify="right"),
BarColumn(bar_width=None),
"[progress.percentage]{task.percentage:>3.1f}%",
"•",
DownloadColumn(),
"•",
TransferSpeedColumn(),
"•",
TimeRemainingColumn(),
transient=True,
console=console,
)

self.live = Live(
Group(self._spinner, self._overall_progress, self._download_progress), transient=True, refresh_per_second=4
)

self._overall_progress_task_id = self._overall_progress.add_task(".", start=True)
self._total_tasks = 0
self._completed_tasks = 0

def _add_sub_task(self, name: str, size: float) -> TaskID:
task_id = self._download_progress.add_task(self._type, path=name, start=True, total=size)
self._total_tasks += 1
self._overall_progress.update(self._overall_progress_task_id, total=self._total_tasks)
return task_id

def _reset_sub_task(self, task_id: TaskID):
self._download_progress.reset(task_id)

def _complete_progress(self):
# TODO: we could probably implement some callback progression from the server
# to get progress reports for the post processing too
# so we don't have to just spin here
self._overall_progress.remove_task(self._overall_progress_task_id)
self._spinner.update(text="Post processing...")

def _complete_sub_task(self, task_id: TaskID):
self._completed_tasks += 1
self._download_progress.remove_task(task_id)
self._overall_progress.update(
self._overall_progress_task_id,
advance=1,
description=f"({self._completed_tasks} out of {self._total_tasks} files completed)",
)

def _advance_sub_task(self, task_id: TaskID, advance: float):
self._download_progress.update(task_id, advance=advance)

def progress(
self,
task_id: Optional[TaskID] = None,
advance: Optional[float] = None,
name: Optional[str] = None,
size: Optional[float] = None,
reset: Optional[bool] = False,
complete: Optional[bool] = False,
) -> Optional[TaskID]:
if task_id is not None:
if reset:
return self._reset_sub_task(task_id)
elif complete:
return self._complete_sub_task(task_id)
elif advance is not None:
return self._advance_sub_task(task_id, advance)
elif name is not None and size is not None:
return self._add_sub_task(name, size)
elif complete:
return self._complete_progress()

raise NotImplementedError(
"Unknown action to take with args: "
+ f"name={name} "
+ f"size={size} "
+ f"task_id={task_id} "
+ f"advance={advance} "
+ f"reset={reset} "
+ f"complete={complete} "
)


async def stream_pty_shell_input(client: _Client, exec_id: str, finish_event: asyncio.Event):
"""
Streams stdin to the given exec id until finish_event is triggered
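
A rough usage sketch of the new ProgressHandler, mirroring the CLI call sites further down in this commit; the file name and byte counts below are made up for illustration:

from rich.console import Console

from modal._output import ProgressHandler

console = Console()
handler = ProgressHandler(type="upload", console=console)
with handler.live:
    # One sub-task per file; the returned TaskID is reused for subsequent updates.
    task_id = handler.progress(name="data/weights.bin", size=1_000_000)
    for _ in range(10):
        handler.progress(task_id=task_id, advance=100_000)  # bytes transferred so far
    handler.progress(task_id=task_id, complete=True)  # this file is done
    handler.progress(complete=True)  # everything transferred: show the post-processing spinner
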
27 changes: 23 additions & 4 deletions modal/_utils/blob_utils.py
@@ -53,6 +53,7 @@ def __init__(
segment_start: int,
segment_length: int,
chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
progress_report_cb: Optional[Callable] = None,
):
# not thread safe constructor!
super().__init__(bytes_io)
@@ -63,6 +64,7 @@ def __init__(
self._value.seek(self.initial_seek_pos + segment_start)
assert self.segment_length <= super().size
self.chunk_size = chunk_size
self.progress_report_cb = progress_report_cb or (lambda *_, **__: None)
self.reset_state()

def reset_state(self):
@@ -74,6 +76,12 @@ def reset_state(self):
def reset_on_error(self):
try:
yield
except Exception as exc:
try:
self.progress_report_cb(reset=True)
except Exception as cb_exc:
raise cb_exc from exc
raise exc
finally:
self.reset_state()

@@ -100,9 +108,13 @@ async def safe_read():
chunk = await safe_read()
while chunk and self.remaining_bytes() > 0:
await writer.write(chunk)
self.progress_report_cb(advance=len(chunk))
chunk = await safe_read()
if chunk:
await writer.write(chunk)
self.progress_report_cb(advance=len(chunk))

self.progress_report_cb(complete=True)

def remaining_bytes(self):
return self.segment_length - self.num_bytes_read
@@ -165,6 +177,7 @@ async def perform_multipart_upload(
part_urls: List[str],
completion_url: str,
upload_chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
progress_report_cb: Optional[Callable] = None,
):
upload_coros = []
file_offset = 0
@@ -187,6 +200,7 @@ async def perform_multipart_upload(
segment_start=file_offset,
segment_length=part_length_bytes,
chunk_size=upload_chunk_size,
progress_report_cb=progress_report_cb,
)
upload_coros.append(_upload_to_s3_url(part_url, payload=part_payload, content_type=None))
num_bytes_left -= part_length_bytes
@@ -230,7 +244,9 @@ def get_content_length(data: BinaryIO):
return content_length - pos


async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub) -> str:
async def _blob_upload(
upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub, progress_report_cb: Optional[Callable] = None
) -> str:
if isinstance(data, bytes):
data = io.BytesIO(data)

@@ -253,9 +269,12 @@ async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO]
part_urls=resp.multipart.upload_urls,
completion_url=resp.multipart.completion_url,
upload_chunk_size=DEFAULT_SEGMENT_CHUNK_SIZE,
progress_report_cb=progress_report_cb,
)
else:
payload = BytesIOSegmentPayload(data, segment_start=0, segment_length=content_length)
payload = BytesIOSegmentPayload(
data, segment_start=0, segment_length=content_length, progress_report_cb=progress_report_cb
)
await _upload_to_s3_url(
resp.upload_url,
payload,
@@ -274,9 +293,9 @@ async def blob_upload(payload: bytes, stub) -> str:
return await _blob_upload(upload_hashes, payload, stub)


async def blob_upload_file(file_obj: BinaryIO, stub) -> str:
async def blob_upload_file(file_obj: BinaryIO, stub, progress_report_cb: Optional[Callable] = None) -> str:
upload_hashes = get_upload_hashes(file_obj)
return await _blob_upload(upload_hashes, file_obj, stub)
return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb)


@retry(n_attempts=5, base_delay=0.1, timeout=None)
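
In this file the new progress_report_cb is only ever called with keyword arguments (advance=..., reset=True, complete=True), so any callable that accepts those keywords can stand in for ProgressHandler.progress. A minimal sketch, with print_progress as a hypothetical stand-in:

from typing import Optional

def print_progress(
    advance: Optional[float] = None,
    reset: bool = False,
    complete: bool = False,
    **kwargs,
) -> None:
    # Mirrors the keyword calls BytesIOSegmentPayload makes above.
    if reset:
        print("segment is being retried; its progress was reset")
    elif complete:
        print("segment fully uploaded")
    elif advance is not None:
        print(f"uploaded another {advance} bytes")

# Hypothetical wiring (file_obj and stub provided by the caller):
#     blob_id = await blob_upload_file(file_obj, stub, progress_report_cb=print_progress)
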
13 changes: 11 additions & 2 deletions modal/cli/_download.py
@@ -4,11 +4,12 @@
import shutil
import sys
from pathlib import Path, PurePosixPath
from typing import AsyncIterator, Optional, Tuple, Union
from typing import AsyncIterator, Callable, Optional, Tuple, Union

from click import UsageError

from modal._utils.async_utils import TaskContext
from modal.config import logger
from modal.network_file_system import _NetworkFileSystem
from modal.volume import FileEntry, FileEntryType, _Volume

@@ -20,6 +21,7 @@ async def _volume_download(
remote_path: str,
local_destination: Path,
overwrite: bool,
progress_cb: Callable,
):
is_pipe = local_destination == PIPE_PATH

@@ -66,21 +68,28 @@ async def consumer():
try:
if is_pipe:
if entry.type == FileEntryType.FILE:
progress_task_id = progress_cb(name=entry.path, size=entry.size)
async for chunk in volume.read_file(entry.path):
sys.stdout.buffer.write(chunk)
progress_cb(task_id=progress_task_id, advance=len(chunk))
progress_cb(task_id=progress_task_id, complete=True)
else:
if entry.type == FileEntryType.FILE:
progress_task_id = progress_cb(name=entry.path, size=entry.size)
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("wb") as fp:
b = 0
async for chunk in volume.read_file(entry.path):
b += fp.write(chunk)
print(f"Wrote {b} bytes to {output_path}", file=sys.stderr)
progress_cb(task_id=progress_task_id, advance=len(chunk))
logger.debug(f"Wrote {b} bytes to {output_path}")
progress_cb(task_id=progress_task_id, complete=True)
elif entry.type == FileEntryType.DIRECTORY:
output_path.mkdir(parents=True, exist_ok=True)
finally:
q.task_done()

consumers = [consumer() for _ in range(num_consumers)]
await TaskContext.gather(producer(), *consumers)
progress_cb(complete=True)
sys.stdout.flush()
23 changes: 14 additions & 9 deletions modal/cli/network_file_system.py
@@ -8,14 +8,13 @@
from click import UsageError
from grpclib import GRPCError, Status
from rich.console import Console
from rich.live import Live
from rich.syntax import Syntax
from rich.table import Table
from typer import Typer

import modal
from modal._location import display_location
from modal._output import step_completed, step_progress
from modal._output import ProgressHandler, step_completed
from modal._utils.async_utils import synchronizer
from modal._utils.grpc_utils import retry_transient_errors
from modal.cli._download import _volume_download
@@ -143,17 +142,19 @@ async def put(
console = Console()

if Path(local_path).is_dir():
spinner = step_progress(f"Uploading directory '{local_path}' to '{remote_path}'...")
with Live(spinner, console=console):
await volume.add_local_dir(local_path, remote_path)
progress_handler = ProgressHandler(type="upload", console=console)
with progress_handler.live:
await volume.add_local_dir(local_path, remote_path, progress_cb=progress_handler.progress)
progress_handler.progress(complete=True)
console.print(step_completed(f"Uploaded directory '{local_path}' to '{remote_path}'"))

elif "*" in local_path:
raise UsageError("Glob uploads are currently not supported")
else:
spinner = step_progress(f"Uploading file '{local_path}' to '{remote_path}'...")
with Live(spinner, console=console):
written_bytes = await volume.add_local_file(local_path, remote_path)
progress_handler = ProgressHandler(type="upload", console=console)
with progress_handler.live:
written_bytes = await volume.add_local_file(local_path, remote_path, progress_cb=progress_handler.progress)
progress_handler.progress(complete=True)
console.print(
step_completed(f"Uploaded file '{local_path}' to '{remote_path}' ({written_bytes} bytes written)")
)
@@ -189,7 +190,11 @@ async def get(
ensure_env(env)
destination = Path(local_destination)
volume = await _volume_from_name(volume_name)
await _volume_download(volume, remote_path, destination, force)
console = Console()
progress_handler = ProgressHandler(type="download", console=console)
with progress_handler.live:
await _volume_download(volume, remote_path, destination, force, progress_cb=progress_handler.progress)
console.print(step_completed("Finished downloading files to local!"))


@nfs_cli.command(
25 changes: 16 additions & 9 deletions modal/cli/volume.py
@@ -8,12 +8,11 @@
from click import UsageError
from grpclib import GRPCError, Status
from rich.console import Console
from rich.live import Live
from rich.syntax import Syntax
from typer import Argument, Option, Typer

import modal
from modal._output import step_completed, step_progress
from modal._output import ProgressHandler, step_completed
from modal._utils.async_utils import synchronizer
from modal._utils.grpc_utils import retry_transient_errors
from modal.cli._download import _volume_download
@@ -97,7 +96,11 @@ async def get(
ensure_env(env)
destination = Path(local_destination)
volume = await _Volume.lookup(volume_name, environment_name=env)
await _volume_download(volume, remote_path, destination, force)
console = Console()
progress_handler = ProgressHandler(type="download", console=console)
with progress_handler.live:
await _volume_download(volume, remote_path, destination, force, progress_cb=progress_handler.progress)
console.print(step_completed("Finished downloading files to local!"))


@volume_cli.command(
@@ -195,24 +198,28 @@ async def put(
if remote_path.endswith("/"):
remote_path = remote_path + os.path.basename(local_path)
console = Console()
progress_handler = ProgressHandler(type="upload", console=console)

if Path(local_path).is_dir():
spinner = step_progress(f"Uploading directory '{local_path}' to '{remote_path}'...")
with Live(spinner, console=console):
with progress_handler.live:
try:
async with _VolumeUploadContextManager(vol.object_id, vol._client, force=force) as batch:
async with _VolumeUploadContextManager(
vol.object_id, vol._client, progress_cb=progress_handler.progress, force=force
) as batch:
batch.put_directory(local_path, remote_path)
except FileExistsError as exc:
raise UsageError(str(exc))
console.print(step_completed(f"Uploaded directory '{local_path}' to '{remote_path}'"))
elif "*" in local_path:
raise UsageError("Glob uploads are currently not supported")
else:
spinner = step_progress(f"Uploading file '{local_path}' to '{remote_path}'...")
with Live(spinner, console=console):
with progress_handler.live:
try:
async with _VolumeUploadContextManager(vol.object_id, vol._client, force=force) as batch:
async with _VolumeUploadContextManager(
vol.object_id, vol._client, progress_cb=progress_handler.progress, force=force
) as batch:
batch.put_file(local_path, remote_path)

except FileExistsError as exc:
raise UsageError(str(exc))
console.print(step_completed(f"Uploaded file '{local_path}' to '{remote_path}'"))
(Diff for the remaining two changed files is not shown above.)
