diff --git a/modal/_output.py b/modal/_output.py index bf7794669..1bba847d5 100644 --- a/modal/_output.py +++ b/modal/_output.py @@ -431,16 +431,23 @@ def __init__(self, type: str, console: Console): self._total_tasks = 0 self._completed_tasks = 0 - def _add_sub_task(self, name, size): + def _add_sub_task(self, name: str, size: float) -> TaskID: task_id = self._download_progress.add_task(self._type, path=name, start=True, total=size) self._total_tasks += 1 self._overall_progress.update(self._overall_progress_task_id, total=self._total_tasks) return task_id - def _reset_sub_task(self, task_id): + def _reset_sub_task(self, task_id: TaskID): self._download_progress.reset(task_id) - def _complete_sub_task(self, task_id): + def _complete_progress(self): + # TODO: we could probably implement some callback progression from the server + # to get progress reports for the post processing too + # so we don't have to just spin here + self._overall_progress.remove_task(self._overall_progress_task_id) + self._spinner.update(text="Post processing...") + + def _complete_sub_task(self, task_id: TaskID): self._completed_tasks += 1 self._download_progress.remove_task(task_id) self._overall_progress.update( @@ -448,11 +455,8 @@ def _complete_sub_task(self, task_id): advance=1, description=f"({self._completed_tasks} out of {self._total_tasks} files completed)", ) - if self._completed_tasks == self._total_tasks: - self._overall_progress.remove_task(self._overall_progress_task_id) - self._spinner.update(text="Post processing...") - def _advance_sub_task(self, task_id, advance): + def _advance_sub_task(self, task_id: TaskID, advance: float): self._download_progress.update(task_id, advance=advance) def progress( @@ -463,19 +467,19 @@ def progress( size: Optional[float] = None, reset: Optional[bool] = False, complete: Optional[bool] = False, - ): + ) -> Optional[TaskID]: if task_id is not None: if reset: - self._reset_sub_task(task_id) - return task_id - if complete: - self._complete_sub_task(task_id) - return task_id - if advance is not None: - self._advance_sub_task(task_id, advance) - return task_id - if name is not None and size is not None: + return self._reset_sub_task(task_id) + elif complete: + return self._complete_sub_task(task_id) + elif advance is not None: + return self._advance_sub_task(task_id, advance) + elif name is not None and size is not None: return self._add_sub_task(name, size) + elif complete: + return self._complete_progress() + raise NotImplementedError( "Unknown action to take with args: " + f"name={name} " diff --git a/modal/network_file_system.py b/modal/network_file_system.py index 0361407a6..ee015628d 100644 --- a/modal/network_file_system.py +++ b/modal/network_file_system.py @@ -262,7 +262,6 @@ async def write_file(self, remote_path: str, fp: BinaryIO, progress_cb: Callable resumable=True, ) else: - # TODO: start task here too data = fp.read() req = api_pb2.SharedVolumePutFileRequest( shared_volume_id=self.object_id, path=remote_path, data=data, resumable=True diff --git a/modal/volume.py b/modal/volume.py index 80bcf7fd9..c2f33d80d 100644 --- a/modal/volume.py +++ b/modal/volume.py @@ -582,6 +582,7 @@ async def gen_file_upload_specs() -> AsyncGenerator[FileUploadSpec, None]: # Upload files uploads_stream = aiostream.stream.map(files_stream, self._upload_file, task_limit=20) files: List[api_pb2.MountFile] = await aiostream.stream.list(uploads_stream) + self._progress_cb(complete=True) request = api_pb2.VolumePutFilesRequest( volume_id=self._volume_id, @@ -666,6 +667,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile: f"Uploading file {file_spec.source_description} to {remote_filename} ({file_spec.size} bytes)" ) request2 = api_pb2.MountPutFileRequest(data=file_spec.content, sha256_hex=file_spec.sha256_hex) + self._progress_cb(task_id=progress_task_id, complete=True) while (time.monotonic() - start_time) < VOLUME_PUT_FILE_CLIENT_TIMEOUT: response = await retry_transient_errors(self._client.stub.MountPutFile, request2, base_delay=1) @@ -674,8 +676,6 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile: if not response.exists: raise VolumeUploadTimeoutError(f"Uploading of {file_spec.source_description} timed out") - elif not file_spec.use_blob: - self._progress_cb(task_id=progress_task_id, complete=True) else: self._progress_cb(task_id=progress_task_id, complete=True) return api_pb2.MountFile(