diff --git a/modal/_output.py b/modal/_output.py
index a24a8a05a..784442361 100644
--- a/modal/_output.py
+++ b/modal/_output.py
@@ -387,6 +387,118 @@ def hide_status_spinner(self):
         self._status_spinner_live.stop()


+class ProgressHandler:
+    live: Live
+    _type: str
+    _spinner: Spinner
+    _overall_progress: Progress
+    _download_progress: Progress
+    _overall_progress_task_id: TaskID
+    _total_tasks: int
+    _completed_tasks: int
+
+    def __init__(self, type: str, console: Console):
+        self._type = type
+
+        if self._type == "download":
+            title = "Downloading file(s) to local..."
+        elif self._type == "upload":
+            title = "Uploading file(s) to volume..."
+        else:
+            raise NotImplementedError(f"Progress handler of type: `{type}` not yet implemented")
+
+        self._spinner = step_progress(title)
+
+        self._overall_progress = Progress(
+            TextColumn(f"[bold white]{title}", justify="right"),
+            TimeElapsedColumn(),
+            BarColumn(bar_width=None),
+            TextColumn("[bold white]{task.description}"),
+            transient=True,
+            console=console,
+        )
+        self._download_progress = Progress(
+            TextColumn("[bold white]{task.fields[path]}", justify="right"),
+            BarColumn(bar_width=None),
+            "[progress.percentage]{task.percentage:>3.1f}%",
+            "•",
+            DownloadColumn(),
+            "•",
+            TransferSpeedColumn(),
+            "•",
+            TimeRemainingColumn(),
+            transient=True,
+            console=console,
+        )
+
+        self.live = Live(
+            Group(self._spinner, self._overall_progress, self._download_progress), transient=True, refresh_per_second=4
+        )
+
+        self._overall_progress_task_id = self._overall_progress.add_task(".", start=True)
+        self._total_tasks = 0
+        self._completed_tasks = 0
+
+    def _add_sub_task(self, name: str, size: float) -> TaskID:
+        task_id = self._download_progress.add_task(self._type, path=name, start=True, total=size)
+        self._total_tasks += 1
+        self._overall_progress.update(self._overall_progress_task_id, total=self._total_tasks)
+        return task_id
+
+    def _reset_sub_task(self, task_id: TaskID):
+        self._download_progress.reset(task_id)
+
+    def _complete_progress(self):
+        # TODO: we could probably implement some callback progression from the server
+        # to get progress reports for the post processing too
+        # so we don't have to just spin here
+        self._overall_progress.remove_task(self._overall_progress_task_id)
+        self._spinner.update(text="Post processing...")
+
+    def _complete_sub_task(self, task_id: TaskID):
+        self._completed_tasks += 1
+        self._download_progress.remove_task(task_id)
+        self._overall_progress.update(
+            self._overall_progress_task_id,
+            advance=1,
+            description=f"({self._completed_tasks} out of {self._total_tasks} files completed)",
+        )
+
+    def _advance_sub_task(self, task_id: TaskID, advance: float):
+        self._download_progress.update(task_id, advance=advance)
+
+    def progress(
+        self,
+        task_id: Optional[TaskID] = None,
+        advance: Optional[float] = None,
+        name: Optional[str] = None,
+        size: Optional[float] = None,
+        reset: Optional[bool] = False,
+        complete: Optional[bool] = False,
+    ) -> Optional[TaskID]:
+        if task_id is not None:
+            if reset:
+                return self._reset_sub_task(task_id)
+            elif complete:
+                return self._complete_sub_task(task_id)
+            elif advance is not None:
+                return self._advance_sub_task(task_id, advance)
+        elif name is not None and size is not None:
+            return self._add_sub_task(name, size)
+        elif complete:
+            return self._complete_progress()
+
+        raise NotImplementedError(
+            "Unknown action to take with args: "
+            + f"name={name} "
+            + f"size={size} "
+            + f"task_id={task_id} "
+            + f"advance={advance} "
+            + f"reset={reset} "
+            + f"complete={complete} "
+        )
+
+
 async def stream_pty_shell_input(client: _Client, exec_id: str, finish_event: asyncio.Event):
     """
     Streams stdin to the given exec id until finish_event is triggered
diff --git a/modal/_utils/blob_utils.py b/modal/_utils/blob_utils.py
index 56f32f2d0..0c02f46f6 100644
--- a/modal/_utils/blob_utils.py
+++ b/modal/_utils/blob_utils.py
@@ -53,6 +53,7 @@ def __init__(
         segment_start: int,
         segment_length: int,
         chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
+        progress_report_cb: Optional[Callable] = None,
     ):
         # not thread safe constructor!
         super().__init__(bytes_io)
@@ -63,6 +64,7 @@ def __init__(
         self._value.seek(self.initial_seek_pos + segment_start)
         assert self.segment_length <= super().size
         self.chunk_size = chunk_size
+        self.progress_report_cb = progress_report_cb or (lambda *_, **__: None)
         self.reset_state()

     def reset_state(self):
@@ -74,6 +76,12 @@ def reset_state(self):
     def reset_on_error(self):
         try:
             yield
+        except Exception as exc:
+            try:
+                self.progress_report_cb(reset=True)
+            except Exception as cb_exc:
+                raise cb_exc from exc
+            raise exc
         finally:
             self.reset_state()

@@ -100,9 +108,13 @@ async def safe_read():
             chunk = await safe_read()
             while chunk and self.remaining_bytes() > 0:
                 await writer.write(chunk)
+                self.progress_report_cb(advance=len(chunk))
                 chunk = await safe_read()
             if chunk:
                 await writer.write(chunk)
+                self.progress_report_cb(advance=len(chunk))
+
+        self.progress_report_cb(complete=True)

     def remaining_bytes(self):
         return self.segment_length - self.num_bytes_read
@@ -165,6 +177,7 @@ async def perform_multipart_upload(
     part_urls: List[str],
     completion_url: str,
     upload_chunk_size: int = DEFAULT_SEGMENT_CHUNK_SIZE,
+    progress_report_cb: Optional[Callable] = None,
 ):
     upload_coros = []
     file_offset = 0
@@ -187,6 +200,7 @@ async def perform_multipart_upload(
                 segment_start=file_offset,
                 segment_length=part_length_bytes,
                 chunk_size=upload_chunk_size,
+                progress_report_cb=progress_report_cb,
             )
             upload_coros.append(_upload_to_s3_url(part_url, payload=part_payload, content_type=None))
             num_bytes_left -= part_length_bytes
@@ -230,7 +244,9 @@ def get_content_length(data: BinaryIO):
         return content_length - pos


-async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub) -> str:
+async def _blob_upload(
+    upload_hashes: UploadHashes, data: Union[bytes, BinaryIO], stub, progress_report_cb: Optional[Callable] = None
+) -> str:
     if isinstance(data, bytes):
         data = io.BytesIO(data)

@@ -253,9 +269,12 @@ async def _blob_upload(upload_hashes: UploadHashes, data: Union[bytes, BinaryIO]
             part_urls=resp.multipart.upload_urls,
             completion_url=resp.multipart.completion_url,
             upload_chunk_size=DEFAULT_SEGMENT_CHUNK_SIZE,
+            progress_report_cb=progress_report_cb,
         )
     else:
-        payload = BytesIOSegmentPayload(data, segment_start=0, segment_length=content_length)
+        payload = BytesIOSegmentPayload(
+            data, segment_start=0, segment_length=content_length, progress_report_cb=progress_report_cb
+        )
         await _upload_to_s3_url(
             resp.upload_url,
             payload,
@@ -274,9 +293,9 @@ async def blob_upload(payload: bytes, stub) -> str:
     return await _blob_upload(upload_hashes, payload, stub)


-async def blob_upload_file(file_obj: BinaryIO, stub) -> str:
+async def blob_upload_file(file_obj: BinaryIO, stub, progress_report_cb: Optional[Callable] = None) -> str:
     upload_hashes = get_upload_hashes(file_obj)
-    return await _blob_upload(upload_hashes, file_obj, stub)
+    return await _blob_upload(upload_hashes, file_obj, stub, progress_report_cb)


 @retry(n_attempts=5, base_delay=0.1, timeout=None)
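Note (illustrative sketch, not part of the patch): the two files above agree on a small callback protocol. `ProgressHandler.progress` multiplexes on its keyword arguments, while `blob_utils` expects a callback already bound to a task id; the handler type, file name, and sizes below are made up for illustration:

    import functools

    from rich.console import Console

    from modal._output import ProgressHandler

    handler = ProgressHandler(type="upload", console=Console())
    with handler.live:
        # name + size, no task_id -> _add_sub_task; returns a rich TaskID
        task_id = handler.progress(name="weights.bin", size=1 << 20)

        # Binding task_id positionally yields the shape BytesIOSegmentPayload calls:
        # a callback taking only advance=/reset=/complete= keywords.
        report = functools.partial(handler.progress, task_id)
        report(advance=512 * 1024)  # task_id + advance -> _advance_sub_task
        report(advance=512 * 1024)
        report(complete=True)       # task_id + complete -> _complete_sub_task

        handler.progress(complete=True)  # no task_id -> _complete_progress ("Post processing...")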
diff --git a/modal/cli/_download.py b/modal/cli/_download.py
index 7f9f114c5..c410c6858 100644
--- a/modal/cli/_download.py
+++ b/modal/cli/_download.py
@@ -4,11 +4,12 @@
 import shutil
 import sys
 from pathlib import Path, PurePosixPath
-from typing import AsyncIterator, Optional, Tuple, Union
+from typing import AsyncIterator, Callable, Optional, Tuple, Union

 from click import UsageError

 from modal._utils.async_utils import TaskContext
+from modal.config import logger
 from modal.network_file_system import _NetworkFileSystem
 from modal.volume import FileEntry, FileEntryType, _Volume

@@ -20,6 +21,7 @@ async def _volume_download(
     remote_path: str,
     local_destination: Path,
     overwrite: bool,
+    progress_cb: Callable,
 ):
     is_pipe = local_destination == PIPE_PATH

@@ -66,16 +68,22 @@ async def consumer():
             try:
                 if is_pipe:
                     if entry.type == FileEntryType.FILE:
+                        progress_task_id = progress_cb(name=entry.path, size=entry.size)
                         async for chunk in volume.read_file(entry.path):
                             sys.stdout.buffer.write(chunk)
+                            progress_cb(task_id=progress_task_id, advance=len(chunk))
+                        progress_cb(task_id=progress_task_id, complete=True)
                 else:
                     if entry.type == FileEntryType.FILE:
+                        progress_task_id = progress_cb(name=entry.path, size=entry.size)
                         output_path.parent.mkdir(parents=True, exist_ok=True)
                         with output_path.open("wb") as fp:
                             b = 0
                             async for chunk in volume.read_file(entry.path):
                                 b += fp.write(chunk)
-                        print(f"Wrote {b} bytes to {output_path}", file=sys.stderr)
+                                progress_cb(task_id=progress_task_id, advance=len(chunk))
+                        logger.debug(f"Wrote {b} bytes to {output_path}")
+                        progress_cb(task_id=progress_task_id, complete=True)
                     elif entry.type == FileEntryType.DIRECTORY:
                         output_path.mkdir(parents=True, exist_ok=True)
             finally:
@@ -83,4 +91,5 @@ async def consumer():

     consumers = [consumer() for _ in range(num_consumers)]
     await TaskContext.gather(producer(), *consumers)
+    progress_cb(complete=True)
     sys.stdout.flush()
diff --git a/modal/cli/network_file_system.py b/modal/cli/network_file_system.py
index e548d2efb..d82eb24af 100644
--- a/modal/cli/network_file_system.py
+++ b/modal/cli/network_file_system.py
@@ -8,14 +8,13 @@
 from click import UsageError
 from grpclib import GRPCError, Status
 from rich.console import Console
-from rich.live import Live
 from rich.syntax import Syntax
 from rich.table import Table
 from typer import Typer

 import modal
 from modal._location import display_location
-from modal._output import step_completed, step_progress
+from modal._output import ProgressHandler, step_completed
 from modal._utils.async_utils import synchronizer
 from modal._utils.grpc_utils import retry_transient_errors
 from modal.cli._download import _volume_download
@@ -143,17 +142,19 @@ async def put(
     console = Console()

     if Path(local_path).is_dir():
-        spinner = step_progress(f"Uploading directory '{local_path}' to '{remote_path}'...")
-        with Live(spinner, console=console):
-            await volume.add_local_dir(local_path, remote_path)
+        progress_handler = ProgressHandler(type="upload", console=console)
+        with progress_handler.live:
+            await volume.add_local_dir(local_path, remote_path, progress_cb=progress_handler.progress)
+            progress_handler.progress(complete=True)
         console.print(step_completed(f"Uploaded directory '{local_path}' to '{remote_path}'"))
     elif "*" in local_path:
         raise UsageError("Glob uploads are currently not supported")
     else:
-        spinner = step_progress(f"Uploading file '{local_path}' to '{remote_path}'...")
-        with Live(spinner, console=console):
-            written_bytes = await volume.add_local_file(local_path, remote_path)
+        progress_handler = ProgressHandler(type="upload", console=console)
+        with progress_handler.live:
+            written_bytes = await volume.add_local_file(local_path, remote_path, progress_cb=progress_handler.progress)
+            progress_handler.progress(complete=True)
         console.print(
             step_completed(f"Uploaded file '{local_path}' to '{remote_path}' ({written_bytes} bytes written)")
         )

@@ -189,7 +190,11 @@ async def get(
     ensure_env(env)
     destination = Path(local_destination)
     volume = await _volume_from_name(volume_name)
-    await _volume_download(volume, remote_path, destination, force)
+    console = Console()
+    progress_handler = ProgressHandler(type="download", console=console)
+    with progress_handler.live:
+        await _volume_download(volume, remote_path, destination, force, progress_cb=progress_handler.progress)
+    console.print(step_completed("Finished downloading files to local!"))


 @nfs_cli.command(
diff --git a/modal/cli/volume.py b/modal/cli/volume.py
index 006926487..e91e6fed2 100644
--- a/modal/cli/volume.py
+++ b/modal/cli/volume.py
@@ -8,12 +8,11 @@
 from click import UsageError
 from grpclib import GRPCError, Status
 from rich.console import Console
-from rich.live import Live
 from rich.syntax import Syntax
 from typer import Argument, Option, Typer

 import modal
-from modal._output import step_completed, step_progress
+from modal._output import ProgressHandler, step_completed
 from modal._utils.async_utils import synchronizer
 from modal._utils.grpc_utils import retry_transient_errors
 from modal.cli._download import _volume_download
@@ -97,7 +96,11 @@ async def get(
     ensure_env(env)
     destination = Path(local_destination)
     volume = await _Volume.lookup(volume_name, environment_name=env)
-    await _volume_download(volume, remote_path, destination, force)
+    console = Console()
+    progress_handler = ProgressHandler(type="download", console=console)
+    with progress_handler.live:
+        await _volume_download(volume, remote_path, destination, force, progress_cb=progress_handler.progress)
+    console.print(step_completed("Finished downloading files to local!"))


 @volume_cli.command(
@@ -195,12 +198,14 @@ async def put(
     if remote_path.endswith("/"):
         remote_path = remote_path + os.path.basename(local_path)
     console = Console()
+    progress_handler = ProgressHandler(type="upload", console=console)

     if Path(local_path).is_dir():
-        spinner = step_progress(f"Uploading directory '{local_path}' to '{remote_path}'...")
-        with Live(spinner, console=console):
+        with progress_handler.live:
             try:
-                async with _VolumeUploadContextManager(vol.object_id, vol._client, force=force) as batch:
+                async with _VolumeUploadContextManager(
+                    vol.object_id, vol._client, progress_cb=progress_handler.progress, force=force
+                ) as batch:
                     batch.put_directory(local_path, remote_path)
             except FileExistsError as exc:
                 raise UsageError(str(exc))
@@ -208,11 +213,13 @@ async def put(
     elif "*" in local_path:
         raise UsageError("Glob uploads are currently not supported")
     else:
-        spinner = step_progress(f"Uploading file '{local_path}' to '{remote_path}'...")
-        with Live(spinner, console=console):
+        with progress_handler.live:
             try:
-                async with _VolumeUploadContextManager(vol.object_id, vol._client, force=force) as batch:
+                async with _VolumeUploadContextManager(
+                    vol.object_id, vol._client, progress_cb=progress_handler.progress, force=force
+                ) as batch:
                     batch.put_file(local_path, remote_path)
+
             except FileExistsError as exc:
                 raise UsageError(str(exc))
         console.print(step_completed(f"Uploaded file '{local_path}' to '{remote_path}'"))
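Note (illustrative sketch, not part of the patch): the four CLI commands above now share one pattern — build a `ProgressHandler`, enter its `live` context, and thread `progress_handler.progress` down as `progress_cb`. Condensed, with a hypothetical `get_files` wrapper and arguments:

    from pathlib import Path

    from rich.console import Console

    from modal._output import ProgressHandler, step_completed
    from modal.cli._download import _volume_download
    from modal.volume import _Volume

    async def get_files(volume_name: str, remote_path: str, dest: Path, force: bool = False):
        volume = await _Volume.lookup(volume_name)
        console = Console()
        progress_handler = ProgressHandler(type="download", console=console)
        with progress_handler.live:  # one Live renders spinner + overall bar + per-file bars
            await _volume_download(volume, remote_path, dest, force, progress_cb=progress_handler.progress)
        console.print(step_completed("Finished downloading files to local!"))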
diff --git a/modal/network_file_system.py b/modal/network_file_system.py
index d0fdd33ea..46999224f 100644
--- a/modal/network_file_system.py
+++ b/modal/network_file_system.py
@@ -1,8 +1,9 @@
 # Copyright Modal Labs 2023
+import functools
 import os
 import time
 from pathlib import Path, PurePosixPath
-from typing import AsyncIterator, BinaryIO, List, Optional, Tuple, Type, Union
+from typing import AsyncIterator, BinaryIO, Callable, List, Optional, Tuple, Type, Union

 import aiostream
 from grpclib import GRPCError, Status
@@ -234,7 +235,7 @@ async def create_deployed(
         return resp.shared_volume_id

     @live_method
-    async def write_file(self, remote_path: str, fp: BinaryIO) -> int:
+    async def write_file(self, remote_path: str, fp: BinaryIO, progress_cb: Optional[Callable] = None) -> int:
         """Write from a file object to a path on the network file system, atomically.

         Will create any needed parent directories automatically.
@@ -242,12 +243,17 @@ async def write_file(self, remote_path: str, fp: BinaryIO) -> int:
         If remote_path ends with `/` it's assumed to be a directory and the file
         will be uploaded with its current name to that directory.
         """
+        progress_cb = progress_cb or (lambda *_, **__: None)
+
         sha_hash = get_sha256_hex(fp)
         fp.seek(0, os.SEEK_END)
         data_size = fp.tell()
         fp.seek(0)
         if data_size > LARGE_FILE_LIMIT:
-            blob_id = await blob_upload_file(fp, self._client.stub)
+            progress_task_id = progress_cb(name=remote_path, size=data_size)
+            blob_id = await blob_upload_file(
+                fp, self._client.stub, progress_report_cb=functools.partial(progress_cb, progress_task_id)
+            )
             req = api_pb2.SharedVolumePutFileRequest(
                 shared_volume_id=self.object_id,
                 path=remote_path,
@@ -301,7 +307,10 @@ async def iterdir(self, path: str) -> AsyncIterator[FileEntry]:

     @live_method
     async def add_local_file(
-        self, local_path: Union[Path, str], remote_path: Optional[Union[str, PurePosixPath, None]] = None
+        self,
+        local_path: Union[Path, str],
+        remote_path: Optional[Union[str, PurePosixPath, None]] = None,
+        progress_cb: Optional[Callable] = None,
     ):
         local_path = Path(local_path)
         if remote_path is None:
@@ -310,13 +319,14 @@ async def add_local_file(
             remote_path = PurePosixPath(remote_path).as_posix()

         with local_path.open("rb") as local_file:
-            return await self.write_file(remote_path, local_file)
+            return await self.write_file(remote_path, local_file, progress_cb=progress_cb)

     @live_method
     async def add_local_dir(
         self,
         local_path: Union[Path, str],
         remote_path: Optional[Union[str, PurePosixPath, None]] = None,
+        progress_cb: Optional[Callable] = None,
     ):
         _local_path = Path(local_path)
         if remote_path is None:
@@ -336,7 +346,7 @@ def gen_transfers():
         transfer_paths = aiostream.stream.iterate(gen_transfers())
         await aiostream.stream.map(
             transfer_paths,
-            aiostream.async_(lambda paths: self.add_local_file(paths[0], paths[1])),
+            aiostream.async_(lambda paths: self.add_local_file(paths[0], paths[1], progress_cb)),
             task_limit=20,
         )

diff --git a/modal/volume.py b/modal/volume.py
index 55bde4bd6..4857ae5aa 100644
--- a/modal/volume.py
+++ b/modal/volume.py
@@ -2,6 +2,7 @@
 import asyncio
 import concurrent.futures
 import enum
+import functools
 import os
 import platform
 import re
@@ -546,13 +547,15 @@ class _VolumeUploadContextManager:
     _volume_id: str
     _client: _Client
     _force: bool
+    _progress_cb: Callable
     _upload_generators: List[Generator[Callable[[], FileUploadSpec], None, None]]

-    def __init__(self, volume_id: str, client: _Client, force: bool = False):
+    def __init__(self, volume_id: str, client: _Client, progress_cb: Optional[Callable] = None, force: bool = False):
         """mdmd:hidden"""
         self._volume_id = volume_id
         self._client = client
         self._upload_generators = []
+        self._progress_cb = progress_cb or (lambda *_, **__: None)
         self._force = force

     async def __aenter__(self):
@@ -579,6 +582,7 @@ async def gen_file_upload_specs() -> AsyncGenerator[FileUploadSpec, None]:
             # Upload files
             uploads_stream = aiostream.stream.map(files_stream, self._upload_file, task_limit=20)
             files: List[api_pb2.MountFile] = await aiostream.stream.list(uploads_stream)
+            self._progress_cb(complete=True)

             request = api_pb2.VolumePutFilesRequest(
                 volume_id=self._volume_id,
@@ -644,7 +648,7 @@ def gen():

     async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
         remote_filename = file_spec.mount_filename
-
+        progress_task_id = self._progress_cb(name=remote_filename, size=file_spec.size)
         request = api_pb2.MountPutFileRequest(sha256_hex=file_spec.sha256_hex)
         response = await retry_transient_errors(self._client.stub.MountPutFile, request, base_delay=1)

@@ -653,7 +657,9 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
             if file_spec.use_blob:
                 logger.debug(f"Creating blob file for {file_spec.source_description} ({file_spec.size} bytes)")
                 with file_spec.source() as fp:
-                    blob_id = await blob_upload_file(fp, self._client.stub)
+                    blob_id = await blob_upload_file(
+                        fp, self._client.stub, functools.partial(self._progress_cb, progress_task_id)
+                    )
                 logger.debug(f"Uploading blob file {file_spec.source_description} as {remote_filename}")
                 request2 = api_pb2.MountPutFileRequest(data_blob_id=blob_id, sha256_hex=file_spec.sha256_hex)
             else:
@@ -661,6 +667,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
                     f"Uploading file {file_spec.source_description} to {remote_filename} ({file_spec.size} bytes)"
                 )
                 request2 = api_pb2.MountPutFileRequest(data=file_spec.content, sha256_hex=file_spec.sha256_hex)
+                self._progress_cb(task_id=progress_task_id, complete=True)

             while (time.monotonic() - start_time) < VOLUME_PUT_FILE_CLIENT_TIMEOUT:
                 response = await retry_transient_errors(self._client.stub.MountPutFile, request2, base_delay=1)
@@ -669,7 +676,8 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:

             if not response.exists:
                 raise VolumeUploadTimeoutError(f"Uploading of {file_spec.source_description} timed out")
-
+        else:
+            self._progress_cb(task_id=progress_task_id, complete=True)
         return api_pb2.MountFile(
             filename=remote_filename,
             sha256_hex=file_spec.sha256_hex,