Skip to content

Commit

Permalink
add persistent_workers. test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
heavengate committed Jul 7, 2021
1 parent 758dd7b commit 88e575e
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 9 deletions.
71 changes: 63 additions & 8 deletions python/paddle/fluid/dataloader/dataloader_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
from .batch_sampler import _InfiniteIterableSampler
from .collate import default_collate_fn, default_convert_fn
from .worker import ParentWatchDog, get_worker_info, _worker_loop, \
_DatasetKind, _IterableDatasetStopIteration, _WorkerException
_DatasetKind, _IterableDatasetStopIteration, _WorkerException, \
_ResumeIteration
from .flat import _flatten_batch, _restore_batch

__all__ = ['get_worker_info']
Expand Down Expand Up @@ -248,6 +249,7 @@ def __del__(self):
class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
def __init__(self, loader):
super(_DataLoaderIterMultiProcess, self).__init__(loader)
self._persistent_workers = loader._persistent_workers

assert self._num_workers > 0, "Multi-process DataLoader " \
"invalid num_workers({})".format(self._num_workers)
Expand Down Expand Up @@ -354,15 +356,63 @@ def _init_thread(self):
self._pin_memory)

self._thread_done_event = threading.Event()
    # thread event is only needed in multi-processing mode
self._thread = threading.Thread(
target=self._thread_loop, args=(_current_expected_place(), ))
self._thread.daemon = True
self._thread.start()

def _shutdown_worker(self, worker_id):
if self._worker_status[worker_id]:
self._indices_queues[worker_id].put(None)
self._worker_status[worker_id] = False
def _reset(self):
    """Reset iterator bookkeeping so this object can serve a new epoch.

    Only used in persistent_workers mode: instead of tearing down the
    worker processes at epoch end, ``DataLoader.__iter__`` reuses this
    iterator and calls ``_reset()`` before the next epoch begins.
    """
    # Data read from _data_queue is reordered by _rcvd_idx to keep batch
    # order: results whose index != _rcvd_idx are cached in _task_infos
    # until their turn comes.
    self._send_idx = 0
    self._rcvd_idx = 0
    self._batches_outstanding = 0
    self._task_infos = {}
    self._structure_infos = []

    # Mark every worker as available again for the new epoch.
    self._worker_status = [True] * self._num_workers

    # Resume iteration in the following steps:
    # 1. Resume workers and clear worker caches: put a _ResumeIteration
    #    sentinel into each worker's indices queue as a resume flag.
    for worker_id in range(self._num_workers):
        self._indices_queues[worker_id].put(_ResumeIteration())
    # Drain stale results from the previous epoch: discard everything
    # coming back until each worker has echoed its _ResumeIteration flag
    # (workers send the flag back with a None payload to acknowledge).
    resume_worker_cnt = self._num_workers
    while resume_worker_cnt > 0:
        idx, data = self._get_data()
        if isinstance(idx, _ResumeIteration):
            assert data is None
            resume_worker_cnt -= 1

    # 2. Resume blocking_queue and py_reader, clearing any cached
    #    batches they still hold from the previous epoch.
    self._blocking_queue.reset()
    self._reader.reset()

    # 3. Prefetch: refill the indices queues up to capacity so the
    #    workers start producing batches for the next epoch immediately.
    for _ in range(self._outstanding_capacity):
        self._try_put_indices()

def _clear_and_remove_data_queue(self):
if self._data_queue is not None:
while True:
try:
self._data_queue.get_nowait()
except:
self._data_queue.cancel_join_thread()
self._data_queue.close()
break

def _shutdown_worker(self, worker_id, shutdown=False):
assert self._worker_status[worker_id] or (self._persistent_workers and
shutdown)
self._indices_queues[worker_id].put(None)
self._worker_status[worker_id] = False

def _try_shutdown_all(self, timeout=None):
if not self._shutdown:
Expand All @@ -375,7 +425,7 @@ def _try_shutdown_all(self, timeout=None):
# indices_queue
self._workers_done_event.set()
for i in range(self._num_workers):
self._shutdown_worker(i)
self._shutdown_worker(i, shutdown=True)

if not self._shutdown:
for w in self._workers:
Expand Down Expand Up @@ -457,7 +507,8 @@ def _get_data(self):
# NOTE: _rcvd_idx and _send_idx only record batches among
# workers, if batches among workers drained, there
# may also be data in blocking queue
if self._batches_outstanding < len(self._places):
if self._batches_outstanding < len(
self._places) and not self._persistent_workers:
return None
continue

Expand Down Expand Up @@ -511,7 +562,11 @@ def _get_data(self):
# is discard, outstanding batch number should be decrease
# and another indices should be put for other workers
# may still working.
self._shutdown_worker(data.worker_id)
if self._persistent_workers:
self._worker_status[data.worker_id] = False
else:
self._shutdown_worker(data.worker_id)

self._batches_outstanding -= 1
self._try_put_indices()
continue
Expand Down
11 changes: 11 additions & 0 deletions python/paddle/fluid/dataloader/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def __init__(self, worker_id):
self.worker_id = worker_id


class _ResumeIteration(object):
pass


class _DatasetKind(object):
MAP = 0
ITER = 1
Expand Down Expand Up @@ -292,6 +296,13 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
except queue.Empty:
continue

if isinstance(data, _ResumeIteration):
out_queue.put((data, None))
iterator_drained = False
fetcher = _DatasetKind.create_fetcher(
dataset_kind, dataset, auto_collate_batch, collate_fn, True)
continue

        # None as poison pill, so worker event should be set
if data is None:
assert done_event.is_set() or iterator_drained, \
Expand Down
12 changes: 11 additions & 1 deletion python/paddle/fluid/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ def __init__(self,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None):
worker_init_fn=None,
persistent_workers=False):
self.return_list = return_list
self.collate_fn = collate_fn
self.use_buffer_reader = use_buffer_reader
Expand Down Expand Up @@ -407,6 +408,9 @@ def __init__(self,
self.pin_memory = True if use_pinned_memory(
) is None else use_pinned_memory()

self._persistent_workers = persistent_workers
self._iterator = None

def __len__(self):
if self.dataset_kind == _DatasetKind.ITER:
raise ValueError("length of IterableDataset not supported")
Expand All @@ -419,6 +423,12 @@ def __len__(self):
def __iter__(self):
    """Return an iterator over the dataset for one epoch."""
    # Single-process mode: always a fresh single-process iterator.
    if self.num_workers == 0:
        return _DataLoaderIterSingleProcess(self)
    # persistent_workers: build the multi-process iterator once, then
    # reuse it across epochs, resetting its state each time.
    if self._persistent_workers:
        if self._iterator is None:
            self._iterator = _DataLoaderIterMultiProcess(self)
        else:
            self._iterator._reset()
        return self._iterator
    # Default multi-process mode: new iterator (and workers) per epoch.
    return _DataLoaderIterMultiProcess(self)

Expand Down

0 comments on commit 88e575e

Please sign in to comment.