Skip to content

Commit

Permalink
complete tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
kramstrom committed Jul 12, 2024
1 parent dc9336d commit 47b0a63
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
38 changes: 21 additions & 17 deletions modal/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,28 +431,32 @@ def __init__(self, type: str, console: Console):
self._total_tasks = 0
self._completed_tasks = 0

def _add_sub_task(self, name, size):
def _add_sub_task(self, name: str, size: float) -> TaskID:
task_id = self._download_progress.add_task(self._type, path=name, start=True, total=size)
self._total_tasks += 1
self._overall_progress.update(self._overall_progress_task_id, total=self._total_tasks)
return task_id

def _reset_sub_task(self, task_id):
def _reset_sub_task(self, task_id: TaskID):
self._download_progress.reset(task_id)

def _complete_sub_task(self, task_id):
def _complete_progress(self):
# TODO: we could probably implement some callback progression from the server
# to get progress reports for the post processing too
# so we don't have to just spin here
self._overall_progress.remove_task(self._overall_progress_task_id)
self._spinner.update(text="Post processing...")

def _complete_sub_task(self, task_id: TaskID):
self._completed_tasks += 1
self._download_progress.remove_task(task_id)
self._overall_progress.update(
self._overall_progress_task_id,
advance=1,
description=f"({self._completed_tasks} out of {self._total_tasks} files completed)",
)
if self._completed_tasks == self._total_tasks:
self._overall_progress.remove_task(self._overall_progress_task_id)
self._spinner.update(text="Post processing...")

def _advance_sub_task(self, task_id, advance):
def _advance_sub_task(self, task_id: TaskID, advance: float):
self._download_progress.update(task_id, advance=advance)

def progress(
Expand All @@ -463,19 +467,19 @@ def progress(
size: Optional[float] = None,
reset: Optional[bool] = False,
complete: Optional[bool] = False,
):
) -> Optional[TaskID]:
if task_id is not None:
if reset:
self._reset_sub_task(task_id)
return task_id
if complete:
self._complete_sub_task(task_id)
return task_id
if advance is not None:
self._advance_sub_task(task_id, advance)
return task_id
if name is not None and size is not None:
return self._reset_sub_task(task_id)
elif complete:
return self._complete_sub_task(task_id)
elif advance is not None:
return self._advance_sub_task(task_id, advance)
elif name is not None and size is not None:
return self._add_sub_task(name, size)
elif complete:
return self._complete_progress()

raise NotImplementedError(
"Unknown action to take with args: "
+ f"name={name} "
Expand Down
1 change: 0 additions & 1 deletion modal/network_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ async def write_file(self, remote_path: str, fp: BinaryIO, progress_cb: Callable
resumable=True,
)
else:
# TODO: start task here too
data = fp.read()
req = api_pb2.SharedVolumePutFileRequest(
shared_volume_id=self.object_id, path=remote_path, data=data, resumable=True
Expand Down
4 changes: 2 additions & 2 deletions modal/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ async def gen_file_upload_specs() -> AsyncGenerator[FileUploadSpec, None]:
# Upload files
uploads_stream = aiostream.stream.map(files_stream, self._upload_file, task_limit=20)
files: List[api_pb2.MountFile] = await aiostream.stream.list(uploads_stream)
self._progress_cb(complete=True)

request = api_pb2.VolumePutFilesRequest(
volume_id=self._volume_id,
Expand Down Expand Up @@ -666,6 +667,7 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:
f"Uploading file {file_spec.source_description} to {remote_filename} ({file_spec.size} bytes)"
)
request2 = api_pb2.MountPutFileRequest(data=file_spec.content, sha256_hex=file_spec.sha256_hex)
self._progress_cb(task_id=progress_task_id, complete=True)

while (time.monotonic() - start_time) < VOLUME_PUT_FILE_CLIENT_TIMEOUT:
response = await retry_transient_errors(self._client.stub.MountPutFile, request2, base_delay=1)
Expand All @@ -674,8 +676,6 @@ async def _upload_file(self, file_spec: FileUploadSpec) -> api_pb2.MountFile:

if not response.exists:
raise VolumeUploadTimeoutError(f"Uploading of {file_spec.source_description} timed out")
elif not file_spec.use_blob:
self._progress_cb(task_id=progress_task_id, complete=True)
else:
self._progress_cb(task_id=progress_task_id, complete=True)
return api_pb2.MountFile(
Expand Down

0 comments on commit 47b0a63

Please sign in to comment.