Skip to content

Commit

Permalink
add download progress bars to modal volume cli
Browse files Browse the repository at this point in the history
  • Loading branch information
kramstrom committed Jul 12, 2024
1 parent 13a2772 commit 5bd2df1
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 16 deletions.
22 changes: 16 additions & 6 deletions modal/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,18 +381,28 @@ def hide_status_spinner(self):

class ProgressHandler:
live: Live
_type: str
_spinner: Spinner
_overall_progress: Progress
_download_progress: Progress
_overall_progress_task_id: TaskID
_total_tasks: int
_completed_tasks: int

def __init__(self, console):
self._spinner = step_progress("Uploading file(s) to volume...")
def __init__(self, type: str, console: Console):
self._type = type

if self._type == "download":
title = "Downloading file(s) to local..."
elif self._type == "upload":
title = "Uploading file(s) to volume..."
else:
raise NotImplementedError(f"Progress handler of type: `{type}` not yet implemented")

self._spinner = step_progress(title)

self._overall_progress = Progress(
TextColumn("[bold white]Uploading files", justify="right"),
TextColumn(f"[bold white]{title}", justify="right"),
TimeElapsedColumn(),
BarColumn(bar_width=None),
TextColumn("[bold white]{task.description}"),
Expand All @@ -417,12 +427,12 @@ def __init__(self, console):
Group(self._spinner, self._overall_progress, self._download_progress), transient=True, refresh_per_second=4
)

self._overall_progress_task_id = self._overall_progress.add_task("upload files", start=True)
self._overall_progress_task_id = self._overall_progress.add_task(".", start=True)
self._total_tasks = 0
self._completed_tasks = 0

def _add_sub_task(self, name, size):
task_id = self._download_progress.add_task("upload", path=name, start=True, total=size)
task_id = self._download_progress.add_task(self._type, path=name, start=True, total=size)
self._total_tasks += 1
self._overall_progress.update(self._overall_progress_task_id, total=self._total_tasks)
return task_id
Expand All @@ -436,7 +446,7 @@ def _complete_sub_task(self, task_id):
self._overall_progress.update(
self._overall_progress_task_id,
advance=1,
description=f"({self._completed_tasks} out of {self._total_tasks} files uploaded)",
description=f"({self._completed_tasks} out of {self._total_tasks} files completed)",
)
if self._completed_tasks == self._total_tasks:
self._overall_progress.remove_task(self._overall_progress_task_id)
Expand Down
4 changes: 2 additions & 2 deletions modal/_utils/blob_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ async def safe_read():
chunk = await safe_read()
while chunk and self.remaining_bytes() > 0:
await writer.write(chunk)
self.progress_report_cb(len(chunk))
self.progress_report_cb(advance=len(chunk))
chunk = await safe_read()
if chunk:
await writer.write(chunk)
self.progress_report_cb(len(chunk))
self.progress_report_cb(advance=len(chunk))

self.progress_report_cb(complete=True)

Expand Down
13 changes: 11 additions & 2 deletions modal/cli/_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
import shutil
import sys
from pathlib import Path, PurePosixPath
from typing import AsyncIterator, Optional, Tuple, Union
from typing import AsyncIterator, Callable, Optional, Tuple, Union

from click import UsageError

from modal._utils.async_utils import TaskContext
from modal.network_file_system import _NetworkFileSystem
from modal.volume import FileEntry, FileEntryType, _Volume

from .config import logger

PIPE_PATH = Path("-")


Expand All @@ -20,6 +22,7 @@ async def _volume_download(
remote_path: str,
local_destination: Path,
overwrite: bool,
progress_cb: Callable,
):
is_pipe = local_destination == PIPE_PATH

Expand Down Expand Up @@ -66,16 +69,22 @@ async def consumer():
try:
if is_pipe:
if entry.type == FileEntryType.FILE:
progress_task_id = progress_cb(name=entry.path, size=entry.size)
async for chunk in volume.read_file(entry.path):
sys.stdout.buffer.write(chunk)
progress_cb(task_id=progress_task_id, advance=len(chunk))
progress_cb(task_id=progress_task_id, complete=True)
else:
if entry.type == FileEntryType.FILE:
progress_task_id = progress_cb(name=entry.path, size=entry.size)
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("wb") as fp:
b = 0
async for chunk in volume.read_file(entry.path):
b += fp.write(chunk)
print(f"Wrote {b} bytes to {output_path}", file=sys.stderr)
progress_cb(task_id=progress_task_id, advance=len(chunk))
logger.debug(f"Wrote {b} bytes to {output_path}", file=sys.stderr)
progress_cb(task_id=progress_task_id, complete=True)
elif entry.type == FileEntryType.DIRECTORY:
output_path.mkdir(parents=True, exist_ok=True)
finally:
Expand Down
8 changes: 6 additions & 2 deletions modal/cli/network_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import modal
from modal._location import display_location
from modal._output import step_completed, step_progress
from modal._output import ProgressHandler, step_completed, step_progress
from modal._utils.async_utils import synchronizer
from modal._utils.grpc_utils import retry_transient_errors
from modal.cli._download import _volume_download
Expand Down Expand Up @@ -189,7 +189,11 @@ async def get(
ensure_env(env)
destination = Path(local_destination)
volume = await _volume_from_name(volume_name)
await _volume_download(volume, remote_path, destination, force)
console = Console()
progress_handler = ProgressHandler(type="download", console=console)
with progress_handler.live:
await _volume_download(volume, remote_path, destination, force, progress_cb=progress_handler.progress)
console.print(step_completed("Finished downloading files to local!"))


@nfs_cli.command(
Expand Down
8 changes: 6 additions & 2 deletions modal/cli/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ async def get(
ensure_env(env)
destination = Path(local_destination)
volume = await _Volume.lookup(volume_name, environment_name=env)
await _volume_download(volume, remote_path, destination, force)
console = Console()
progress_handler = ProgressHandler(type="download", console=console)
with progress_handler.live:
await _volume_download(volume, remote_path, destination, force, progress_cb=progress_handler.progress)
console.print(step_completed("Finished downloading files to local!"))


@volume_cli.command(
Expand Down Expand Up @@ -194,7 +198,7 @@ async def put(
if remote_path.endswith("/"):
remote_path = remote_path + os.path.basename(local_path)
console = Console()
progress_handler = ProgressHandler(console=console)
progress_handler = ProgressHandler(type="upload", console=console)

if Path(local_path).is_dir():
with progress_handler.live:
Expand Down
4 changes: 2 additions & 2 deletions modal/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,9 +677,9 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
if not response.exists:
raise VolumeUploadTimeoutError(f"Uploading of {file_spec.source_description} timed out")
elif not file_spec.use_blob:
self._progress_cb(progress_task_id, complete=True)
self._progress_cb(task_id=progress_task_id, complete=True)
else:
self._progress_cb(progress_task_id, complete=True)
self._progress_cb(task_id=progress_task_id, complete=True)
return api_pb2.MountFile(
filename=remote_filename,
sha256_hex=file_spec.sha256_hex,
Expand Down

0 comments on commit 5bd2df1

Please sign in to comment.