Skip to content

Commit

Permalink
balbal
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Aug 5, 2024
1 parent affcba6 commit 2f2f98a
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions python/dgl/graphbolt/internal/datapipe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time

from collections import deque
from typing import final, List, Set, Type
from typing import final, List, Set, Type # pylint: disable=no-name-in-module

from torch.utils.data import functional_datapipe, IterDataPipe, MapDataPipe
from torch.utils.data.graph import DataPipe, DataPipeGraph, traverse_dps
Expand Down Expand Up @@ -209,17 +209,19 @@ def __init__(self, source_datapipe, buffer_size: int):
@functional_datapipe("prefetch")
class PrefetcherIterDataPipe(IterDataPipe):
r"""
Prefetches elements from the source DataPipe and puts them into a buffer (functional name: ``prefetch``).
Prefetching performs the operations (e.g. I/O, computations) of the DataPipes up to this one ahead of time
and stores the result in the buffer, ready to be consumed by the subsequent DataPipe. It has no effect aside
from getting the sample ready ahead of time.
Prefetches elements from the source DataPipe and puts them into a buffer
(functional name: ``prefetch``). Prefetching performs the operations (e.g.
I/O, computations) of the DataPipes up to this one ahead of time and stores
the result in the buffer, ready to be consumed by the subsequent DataPipe.
It has no effect aside from getting the sample ready ahead of time.
This is used by ``MultiProcessingReadingService`` when the arguments
``worker_prefetch_cnt`` (for prefetching at each worker process) or
``main_prefetch_cnt`` (for prefetching at the main loop) are greater than 0.
Beyond the built-in use cases, this can be useful to put after I/O DataPipes that have
expensive I/O operations (e.g. takes a long time to request a file from a remote server).
Beyond the built-in use cases, this can be useful to put after I/O DataPipes
that have expensive I/O operations (e.g. takes a long time to request a file
from a remote server).
Args:
source_datapipe: IterDataPipe from which samples are prefetched
Expand All @@ -241,7 +243,9 @@ def __init__(self, source_datapipe, buffer_size: int = 10):
self.prefetch_data: Optional[_PrefetchData] = None

@staticmethod
def thread_worker(prefetch_data: _PrefetchData):
def thread_worker(
prefetch_data: _PrefetchData,
): # pylint: disable=missing-function-docstring
itr = iter(prefetch_data.source_datapipe)
while not prefetch_data.stop_iteration:
# Run if not paused
Expand All @@ -253,7 +257,7 @@ def thread_worker(prefetch_data: _PrefetchData):
try:
item = next(itr)
prefetch_data.prefetch_buffer.append(item)
except Exception as e:
except Exception as e: # pylint: disable=broad-except
prefetch_data.run_prefetcher = False
prefetch_data.stop_iteration = True
prefetch_data.prefetch_buffer.append(e)
Expand Down Expand Up @@ -320,10 +324,10 @@ def __setstate__(self, state):
self.thread = None

@final
def reset(self):
    """Reset the prefetcher state.

    Delegates to :meth:`shutdown`, which stops the prefetching worker
    thread and clears the associated prefetch state, so a subsequent
    iteration starts from a clean slate.
    """
    # A real docstring is preferable to silencing pylint with
    # `# pylint: disable=missing-function-docstring`.
    self.shutdown()

def pause(self):
def pause(self): # pylint: disable=missing-function-docstring
if self.thread is not None:
assert self.prefetch_data is not None
self.prefetch_data.run_prefetcher = False
Expand All @@ -333,7 +337,7 @@ def pause(self):
time.sleep(PRODUCER_SLEEP_INTERVAL * 10)

@final
def resume(self):
def resume(self): # pylint: disable=missing-function-docstring
if (
self.thread is not None
and self.prefetch_data is not None
Expand All @@ -346,7 +350,7 @@ def resume(self):
self.prefetch_data.paused = False

@final
def shutdown(self):
def shutdown(self): # pylint: disable=missing-function-docstring
if hasattr(self, "prefetch_data") and self.prefetch_data is not None:
self.prefetch_data.run_prefetcher = False
self.prefetch_data.stop_iteration = True
Expand Down

0 comments on commit 2f2f98a

Please sign in to comment.