Skip to content

Commit

Permalink
add persistent_workers (#34017)
Browse files Browse the repository at this point in the history
* add persistent_workers. test=develop
  • Loading branch information
heavengate committed Jul 29, 2021
1 parent b451ff2 commit 76710e5
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 79 deletions.
128 changes: 104 additions & 24 deletions python/paddle/fluid/dataloader/dataloader_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
from .batch_sampler import _InfiniteIterableSampler
from .collate import default_collate_fn, default_convert_fn
from .worker import ParentWatchDog, get_worker_info, _worker_loop, \
_DatasetKind, _IterableDatasetStopIteration, _WorkerException
_DatasetKind, _IterableDatasetStopIteration, _WorkerException, \
_ResumeIteration
from .flat import _flatten_batch, _restore_batch

__all__ = ['get_worker_info']
Expand Down Expand Up @@ -67,15 +68,10 @@ def __init__(self, loader):
self._dataset_kind = loader.dataset_kind
self._pin_memory = loader.pin_memory

self._sampler_iter = iter(self._index_sampler)
if self._auto_collate_batch:
self._sampler_iter = iter(loader.batch_sampler)
self._collate_fn = loader.collate_fn or default_collate_fn
else:
if self._dataset_kind == _DatasetKind.MAP:
self._sampler_iter = iter(list(range(len(self._dataset))))
else:
self._sampler_iter = iter(
_InfiniteIterableSampler(self._dataset, 1))
self._collate_fn = loader.collate_fn or default_convert_fn

# LoDTensorBlockingQueue instance for create_py_reader and a thread
Expand All @@ -87,6 +83,16 @@ def __init__(self, loader):
self._thread = None
self._thread_done_event = threading.Event()

@property
def _index_sampler(self):
if self._auto_collate_batch:
return self._batch_sampler
else:
if self._dataset_kind == _DatasetKind.MAP:
return list(range(len(self._dataset)))
else:
return _InfiniteIterableSampler(self._dataset, 1)

def __iter__(self):
    """The dataloader iterator is itself an iterator; return self."""
    return self

Expand Down Expand Up @@ -242,6 +248,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase):
def __init__(self, loader):
super(_DataLoaderIterMultiProcess, self).__init__(loader)

self._persistent_workers = loader._persistent_workers
self._resume_worker_cnt = 0

assert self._num_workers > 0, "Multi-process DataLoader " \
"invalid num_workers({})".format(self._num_workers)

Expand Down Expand Up @@ -336,13 +345,65 @@ def _init_thread(self):
self._pin_memory)

self._thread_done_event = threading.Event()
# thread event is only need in multi-processing mode
self._thread = threading.Thread(
target=self._thread_loop, args=(_current_expected_place(), ))
self._thread.daemon = True
self._thread.start()

def _shutdown_worker(self, worker_id):
if self._worker_status[worker_id]:
def _reset(self):
    """Resume iteration for the next epoch (persistent-workers mode).

    Resume iteration in the following steps:
      1. resume workers and clear their caches
      2. drain leftover batches from the blocking queue
      3. reset iterator bookkeeping state
      4. rebuild the sampler iterator and prefetch indices
    """
    # 1. Resume workers, clear worker caches:
    #    put a _ResumeIteration flag into every worker's indices queue;
    #    each flag also counts as one outstanding batch so the reader
    #    thread will pull it from the data queue
    with self._thread_lock:
        self._resume_worker_cnt = self._num_workers
        for worker_id in range(self._num_workers):
            self._indices_queues[worker_id].put(_ResumeIteration())
            self._batches_outstanding += 1
    # every flag is acknowledged in _thread_loop (which decrements
    # _resume_worker_cnt), so simply busy-wait here until all resumed
    while self._resume_worker_cnt > 0:
        time.sleep(0.5)

    # 2. clear blocking_queue caches:
    #    in order not to restart the reader thread, we just drain the
    #    blocking_queue caches instead of recreating the queue
    while self._blocking_queue.size() >= len(self._places):
        if in_dygraph_mode():
            self._reader.read_next_var_list()
        elif self._return_list:
            self._reader.read_next_list()
        else:
            data = self._reader.read_next()

    # 3. reset all states
    self._send_idx = 0
    self._rcvd_idx = 0
    self._batches_outstanding = 0
    self._task_infos = {}
    self._structure_infos = []

    # set all worker status available again for the new epoch
    self._worker_status = [True] * self._num_workers

    # 4. reset _sampler_iter and put prefetch indices to start next epoch:
    #    refill the indices queues up to the prefetch capacity
    self._sampler_iter = iter(self._index_sampler)
    for _ in range(self._outstanding_capacity):
        self._try_put_indices()

def _clear_and_remove_data_queue(self):
if self._data_queue is not None:
while True:
try:
self._data_queue.get_nowait()
except:
self._data_queue.cancel_join_thread()
self._data_queue.close()
break

def _shutdown_worker(self, worker_id, shutdown=False):
if self._worker_status[worker_id] or (self._persistent_workers and
shutdown):
self._indices_queues[worker_id].put(None)
self._worker_status[worker_id] = False

Expand All @@ -357,7 +418,7 @@ def _try_shutdown_all(self, timeout=None):
# indices_queue
self._workers_done_event.set()
for i in range(self._num_workers):
self._shutdown_worker(i)
self._shutdown_worker(i, shutdown=True)

if not self._shutdown:
for w in self._workers:
Expand Down Expand Up @@ -392,6 +453,10 @@ def _thread_loop(self, legacy_expected_place):
if batch is None:
self._exit_thread_expectedly()
else:
if isinstance(batch, _ResumeIteration):
assert self._resume_worker_cnt > 0
self._resume_worker_cnt -= 1
continue
try:
# pack as LoDTensorArray
array = core.LoDTensorArray()
Expand All @@ -412,7 +477,7 @@ def _thread_loop(self, legacy_expected_place):

if not self._blocking_queue.push(array):
self._blocking_queue.close()
except:
except Exception as e:
self._exit_thread_unexpectedly()
six.reraise(*sys.exc_info())
finally:
Expand All @@ -428,20 +493,23 @@ def _get_data(self):
# batch indices and increase _rcvd_idx
if self._dataset_kind == _DatasetKind.ITER:
while self._rcvd_idx < self._send_idx:
sys.stdout.flush()
info = self._task_infos[self._rcvd_idx]
if len(info) == 3 or self._worker_status[info[0]]:
break
del self._task_infos[self._rcvd_idx]
self._rcvd_idx += 1
self._batches_outstanding -= 1
else:
# NOTE: _rcvd_idx and _send_idx only record batches among
# workers, if batches among workers drained, there
# may also be data in blocking queue
if self._batches_outstanding < len(self._places):
return None
continue
# NOTE: in persistent workers mode, do not check data
# drained here, simply let it go to _data_queue
# reading to get _ResumeIteration
if not self._persistent_workers:
# NOTE: _rcvd_idx and _send_idx only record batches among
# workers, if batches among workers drained, there
# may also be data in blocking queue
if self._batches_outstanding < len(self._places):
return None
continue

if self._rcvd_idx in self._task_infos and \
len(self._task_infos[self._rcvd_idx]) == 3:
Expand Down Expand Up @@ -493,12 +561,20 @@ def _get_data(self):
# is discard, outstanding batch number should be decrease
# and another indices should be put for other workers
# may still working.
self._shutdown_worker(data.worker_id)
self._batches_outstanding -= 1
if self._persistent_workers:
self._worker_status[data.worker_id] = False
else:
self._shutdown_worker(data.worker_id)
self._batches_outstanding -= 1
self._try_put_indices()
continue

idx, batch, structure = data

if isinstance(idx, _ResumeIteration) and batch is None \
and structure is None:
return idx

if isinstance(batch, _WorkerException):
self._exit_thread_unexpectedly()
batch.reraise()
Expand Down Expand Up @@ -557,8 +633,11 @@ def __next__(self):
# set _thread_done_event here, py_reader will raise StopIteration,
# end workers and indices_queues in StopIteration handling
if self._batches_outstanding < len(self._places):
self._thread_done_event.set()
self._blocking_queue.close()
if self._persistent_workers:
raise StopIteration
else:
self._thread_done_event.set()
self._blocking_queue.close()

if in_dygraph_mode():
data = self._reader.read_next_var_list()
Expand All @@ -583,8 +662,9 @@ def __next__(self):
self._on_output_batch()
return data
except StopIteration:
self._reader.shutdown()
self._try_shutdown_all()
if not self._persistent_workers:
self._reader.shutdown()
self._try_shutdown_all()
six.reraise(*sys.exc_info())

# python2 compatibility
Expand Down
11 changes: 11 additions & 0 deletions python/paddle/fluid/dataloader/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def __init__(self, worker_id):
self.worker_id = worker_id


class _ResumeIteration(object):
    """Sentinel put into worker indices queues (and echoed back through
    the data queue) to signal that iteration should resume for a new
    epoch in persistent-workers mode."""
    pass


class _DatasetKind(object):
MAP = 0
ITER = 1
Expand Down Expand Up @@ -292,6 +296,13 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event,
except queue.Empty:
continue

if isinstance(data, _ResumeIteration):
out_queue.put((data, None, None))
iterator_drained = False
fetcher = _DatasetKind.create_fetcher(
dataset_kind, dataset, auto_collate_batch, collate_fn, True)
continue

# None as poison pill, so the worker done event should be set
if data is None:
assert done_event.is_set() or iterator_drained, \
Expand Down
12 changes: 11 additions & 1 deletion python/paddle/fluid/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ def __init__(self,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None):
worker_init_fn=None,
persistent_workers=False):
self.return_list = return_list
self.collate_fn = collate_fn
self.use_buffer_reader = use_buffer_reader
Expand Down Expand Up @@ -407,6 +408,9 @@ def __init__(self,
self.pin_memory = True if use_pinned_memory(
) is None else use_pinned_memory()

self._persistent_workers = persistent_workers
self._iterator = None

def __len__(self):
if self.dataset_kind == _DatasetKind.ITER:
raise ValueError("length of IterableDataset not supported")
Expand All @@ -419,6 +423,12 @@ def __len__(self):
def __iter__(self):
    """Create an iterator over the dataset.

    Single-process and non-persistent multi-process modes build a fresh
    iterator per epoch; persistent-workers mode keeps one multi-process
    iterator alive and resets it between epochs.
    """
    if self.num_workers == 0:
        return _DataLoaderIterSingleProcess(self)
    if not self._persistent_workers:
        return _DataLoaderIterMultiProcess(self)
    if self._iterator is None:
        self._iterator = _DataLoaderIterMultiProcess(self)
    else:
        self._iterator._reset()
    return self._iterator

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def forward(self, image):


class TestDygraphDataLoader(unittest.TestCase):
def run_main(self, num_workers, places):
def run_main(self, num_workers, places, persistent_workers):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
with fluid.dygraph.guard(places[0]):
Expand All @@ -78,7 +78,8 @@ def run_main(self, num_workers, places):
dataset,
num_workers=num_workers,
batch_size=BATCH_SIZE,
drop_last=True)
drop_last=True,
persistent_workers=persistent_workers)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)

step_list = []
Expand Down Expand Up @@ -110,20 +111,25 @@ def run_main(self, num_workers, places):
def test_main(self):
    """Verify multi-process loading matches single-process results,
    with and without persistent workers."""
    # dynamic graph do not run with_data_parallel
    for place in prepare_places(False):
        for persistent in [False, True]:
            losses = {}
            for workers in [0, 2]:
                print(self.__class__.__name__, place, workers, persistent)
                sys.stdout.flush()
                losses[workers] = self.run_main(
                    num_workers=workers,
                    places=place,
                    persistent_workers=persistent)['loss']
            # relative deviation of the 2-worker loss from the baseline
            rel_diff = np.max(
                np.abs(losses[0] - losses[2]) / np.abs(losses[0]))
            self.assertLess(rel_diff, 1e-2)


class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader):
def run_main(self, num_workers, places):
def run_main(self, num_workers, places, persistent_workers):
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
with fluid.dygraph.guard(places[0]):
Expand All @@ -135,7 +141,8 @@ def run_main(self, num_workers, places):
dataset,
num_workers=num_workers,
batch_size=None,
drop_last=True)
drop_last=True,
persistent_workers=persistent_workers)
assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE)

step_list = []
Expand Down
Loading

0 comments on commit 76710e5

Please sign in to comment.