diff --git a/python/paddle/fluid/dataloader/dataloader_iter.py b/python/paddle/fluid/dataloader/dataloader_iter.py index d315250657d7b..4a9b450be9504 100644 --- a/python/paddle/fluid/dataloader/dataloader_iter.py +++ b/python/paddle/fluid/dataloader/dataloader_iter.py @@ -37,7 +37,8 @@ from .batch_sampler import _InfiniteIterableSampler from .collate import default_collate_fn, default_convert_fn from .worker import ParentWatchDog, get_worker_info, _worker_loop, \ - _DatasetKind, _IterableDatasetStopIteration, _WorkerException + _DatasetKind, _IterableDatasetStopIteration, _WorkerException, \ + _ResumeIteration from .flat import _flatten_batch, _restore_batch __all__ = ['get_worker_info'] @@ -67,15 +68,10 @@ def __init__(self, loader): self._dataset_kind = loader.dataset_kind self._pin_memory = loader.pin_memory + self._sampler_iter = iter(self._index_sampler) if self._auto_collate_batch: - self._sampler_iter = iter(loader.batch_sampler) self._collate_fn = loader.collate_fn or default_collate_fn else: - if self._dataset_kind == _DatasetKind.MAP: - self._sampler_iter = iter(list(range(len(self._dataset)))) - else: - self._sampler_iter = iter( - _InfiniteIterableSampler(self._dataset, 1)) self._collate_fn = loader.collate_fn or default_convert_fn # LoDTensorBlockingQueue instance for create_py_reader and a thread @@ -87,6 +83,16 @@ def __init__(self, loader): self._thread = None self._thread_done_event = threading.Event() + @property + def _index_sampler(self): + if self._auto_collate_batch: + return self._batch_sampler + else: + if self._dataset_kind == _DatasetKind.MAP: + return list(range(len(self._dataset))) + else: + return _InfiniteIterableSampler(self._dataset, 1) + def __iter__(self): return self @@ -242,6 +248,9 @@ class _DataLoaderIterMultiProcess(_DataLoaderIterBase): def __init__(self, loader): super(_DataLoaderIterMultiProcess, self).__init__(loader) + self._persistent_workers = loader._persistent_workers + self._resume_worker_cnt = 0 + assert self._num_workers > 0, "Multi-process DataLoader " \ "invalid num_workers({})".format(self._num_workers) @@ -336,13 +345,65 @@ def _init_thread(self): self._pin_memory) self._thread_done_event = threading.Event() + # thread event is only need in multi-processing mode self._thread = threading.Thread( target=self._thread_loop, args=(_current_expected_place(), )) self._thread.daemon = True self._thread.start() - def _shutdown_worker(self, worker_id): - if self._worker_status[worker_id]: + def _reset(self): + # resume iteration in following steps + # 1. Resume workers, clear worker caches + # put _ResumeIteration to all worker as resume iteration flag + with self._thread_lock: + self._resume_worker_cnt = self._num_workers + for worker_id in range(self._num_workers): + self._indices_queues[worker_id].put(_ResumeIteration()) + self._batches_outstanding += 1 + # all flag will be check in _thread_loop, simply wait here + while self._resume_worker_cnt > 0: + time.sleep(0.5) + + # 2. clear blocking_queue caches + # in order not to restart the thread, we just clear + # the blocking_queue cachees instead of recreating one + while self._blocking_queue.size() >= len(self._places): + if in_dygraph_mode(): + self._reader.read_next_var_list() + elif self._return_list: + self._reader.read_next_list() + else: + data = self._reader.read_next() + + # 3. reset all states + self._send_idx = 0 + self._rcvd_idx = 0 + self._batches_outstanding = 0 + self._task_infos = {} + self._structure_infos = [] + + # set all worker status available + self._worker_status = [True] * self._num_workers + + # 4. reset _sampler_iter and put prefetch indices to start next epoch + # init workers and indices queues and put 2 indices in each indices queue + self._sampler_iter = iter(self._index_sampler) + for _ in range(self._outstanding_capacity): + self._try_put_indices() + + def _clear_and_remove_data_queue(self): + if self._data_queue is not None: + while True: + try: + self._data_queue.get_nowait() + except: + self._data_queue.cancel_join_thread() + self._data_queue.close() + break + + def _shutdown_worker(self, worker_id, shutdown=False): + if self._worker_status[worker_id] or (self._persistent_workers and + shutdown): self._indices_queues[worker_id].put(None) self._worker_status[worker_id] = False @@ -357,7 +418,7 @@ def _try_shutdown_all(self, timeout=None): # indices_queue self._workers_done_event.set() for i in range(self._num_workers): - self._shutdown_worker(i) + self._shutdown_worker(i, shutdown=True) if not self._shutdown: for w in self._workers: @@ -392,6 +453,10 @@ def _thread_loop(self, legacy_expected_place): if batch is None: self._exit_thread_expectedly() else: + if isinstance(batch, _ResumeIteration): + assert self._resume_worker_cnt > 0 + self._resume_worker_cnt -= 1 + continue try: # pack as LoDTensorArray array = core.LoDTensorArray() @@ -412,7 +477,7 @@ def _thread_loop(self, legacy_expected_place): if not self._blocking_queue.push(array): self._blocking_queue.close() - except: + except Exception as e: self._exit_thread_unexpectedly() six.reraise(*sys.exc_info()) finally: @@ -428,7 +493,6 @@ def _get_data(self): # batch indices and increase _rcvd_idx if self._dataset_kind == _DatasetKind.ITER: while self._rcvd_idx < self._send_idx: - sys.stdout.flush() info = self._task_infos[self._rcvd_idx] if len(info) == 3 or self._worker_status[info[0]]: break @@ -436,12 +500,16 @@ def _get_data(self): self._rcvd_idx += 1 self._batches_outstanding -= 1 else: - # NOTE: _rcvd_idx and _send_idx only record batches among - # workers, if batches among workers drained, there - # may also be data in blocking queue - if self._batches_outstanding < len(self._places): - return None - continue + # NOTE: in persistent workers mode, do not check data + # drained here, simply let it go to _data_queue + # reading to get _ResumeIteration + if not self._persistent_workers: + # NOTE: _rcvd_idx and _send_idx only record batches among + # workers, if batches among workers drained, there + # may also be data in blocking queue + if self._batches_outstanding < len(self._places): + return None + continue if self._rcvd_idx in self._task_infos and \ len(self._task_infos[self._rcvd_idx]) == 3: @@ -493,12 +561,20 @@ def _get_data(self): # is discard, outstanding batch number should be decrease # and another indices should be put for other workers # may still working. - self._shutdown_worker(data.worker_id) - self._batches_outstanding -= 1 + if self._persistent_workers: + self._worker_status[data.worker_id] = False + else: + self._shutdown_worker(data.worker_id) + self._batches_outstanding -= 1 self._try_put_indices() continue idx, batch, structure = data + + if isinstance(idx, _ResumeIteration) and batch is None \ + and structure is None: + return idx + if isinstance(batch, _WorkerException): self._exit_thread_unexpectedly() batch.reraise() @@ -557,8 +633,11 @@ def __next__(self): # set _thread_done_event here, py_reader will raise StopIteration, # end workers and indices_queues in StopIteration handling if self._batches_outstanding < len(self._places): - self._thread_done_event.set() - self._blocking_queue.close() + if self._persistent_workers: + raise StopIteration + else: + self._thread_done_event.set() + self._blocking_queue.close() if in_dygraph_mode(): data = self._reader.read_next_var_list() @@ -583,8 +662,9 @@ def __next__(self): self._on_output_batch() return data except StopIteration: - self._reader.shutdown() - self._try_shutdown_all() + if not self._persistent_workers: + self._reader.shutdown() + self._try_shutdown_all() six.reraise(*sys.exc_info()) # python2 compatibility diff --git a/python/paddle/fluid/dataloader/worker.py b/python/paddle/fluid/dataloader/worker.py index 037cf2c4b12d2..66ca4150460d7 100644 --- a/python/paddle/fluid/dataloader/worker.py +++ b/python/paddle/fluid/dataloader/worker.py @@ -36,6 +36,10 @@ def __init__(self, worker_id): self.worker_id = worker_id +class _ResumeIteration(object): + pass + + class _DatasetKind(object): MAP = 0 ITER = 1 @@ -292,6 +296,13 @@ def _worker_loop(dataset, dataset_kind, indices_queue, out_queue, done_event, except queue.Empty: continue + if isinstance(data, _ResumeIteration): + out_queue.put((data, None, None)) + iterator_drained = False + fetcher = _DatasetKind.create_fetcher( + dataset_kind, dataset, auto_collate_batch, collate_fn, True) + continue + # None as poison piil, so worker event should be set if data is None: assert done_event.is_set() or iterator_drained, \ diff --git a/python/paddle/fluid/reader.py b/python/paddle/fluid/reader.py index d5a23cfbdb941..7076ef22ba605 100644 --- a/python/paddle/fluid/reader.py +++ b/python/paddle/fluid/reader.py @@ -325,7 +325,8 @@ def __init__(self, use_buffer_reader=True, use_shared_memory=True, timeout=0, - worker_init_fn=None): + worker_init_fn=None, + persistent_workers=False): self.return_list = return_list self.collate_fn = collate_fn self.use_buffer_reader = use_buffer_reader @@ -407,6 +408,9 @@ def __init__(self, self.pin_memory = True if use_pinned_memory( ) is None else use_pinned_memory() + self._persistent_workers = persistent_workers + self._iterator = None + def __len__(self): if self.dataset_kind == _DatasetKind.ITER: raise ValueError("length of IterableDataset not supported") @@ -419,6 +423,12 @@ def __len__(self): def __iter__(self): if self.num_workers == 0: return _DataLoaderIterSingleProcess(self) + elif self._persistent_workers: + if self._iterator is None: + self._iterator = _DataLoaderIterMultiProcess(self) + else: + self._iterator._reset() + return self._iterator else: return _DataLoaderIterMultiProcess(self) diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py index c89354adf751c..fcc7c17ce06a7 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dynamic.py @@ -66,7 +66,7 @@ def forward(self, image): class TestDygraphDataLoader(unittest.TestCase): - def run_main(self, num_workers, places): + def run_main(self, num_workers, places, persistent_workers): fluid.default_startup_program().random_seed = 1 fluid.default_main_program().random_seed = 1 with fluid.dygraph.guard(places[0]): @@ -78,7 +78,8 @@ def run_main(self, num_workers, places): dataset, num_workers=num_workers, batch_size=BATCH_SIZE, - drop_last=True) + drop_last=True, + persistent_workers=persistent_workers) assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) step_list = [] @@ -110,20 +111,25 @@ def run_main(self, num_workers, places): def test_main(self): # dynamic graph do not run with_data_parallel for p in prepare_places(False): - results = [] - for num_workers in [0, 2]: - print(self.__class__.__name__, p, num_workers) - sys.stdout.flush() - ret = self.run_main(num_workers=num_workers, places=p) - results.append(ret) - diff = np.max( - np.abs(results[0]['loss'] - results[1]['loss']) / - np.abs(results[0]['loss'])) - self.assertLess(diff, 1e-2) + for persistent_workers in [False, True]: + results = [] + for num_workers in [0, 2]: + print(self.__class__.__name__, p, num_workers, + persistent_workers) + sys.stdout.flush() + ret = self.run_main( + num_workers=num_workers, + places=p, + persistent_workers=persistent_workers) + results.append(ret) + diff = np.max( + np.abs(results[0]['loss'] - results[1]['loss']) / + np.abs(results[0]['loss'])) + self.assertLess(diff, 1e-2) class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader): - def run_main(self, num_workers, places): + def run_main(self, num_workers, places, persistent_workers): fluid.default_startup_program().random_seed = 1 fluid.default_main_program().random_seed = 1 with fluid.dygraph.guard(places[0]): @@ -135,7 +141,8 @@ def run_main(self, num_workers, places): dataset, num_workers=num_workers, batch_size=None, - drop_last=True) + drop_last=True, + persistent_workers=persistent_workers) assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) step_list = [] diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py index 3bb3e843b1b11..490e95a0f0be2 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py @@ -66,7 +66,7 @@ def forward(self, image): class TestDygraphDataLoader(unittest.TestCase): - def run_main(self, num_workers, places): + def run_main(self, num_workers, places, persistent_workers): fluid.default_startup_program().random_seed = 1 fluid.default_main_program().random_seed = 1 with fluid.dygraph.guard(places[0]): @@ -78,7 +78,8 @@ def run_main(self, num_workers, places): dataset, num_workers=num_workers, batch_size=BATCH_SIZE, - drop_last=True) + drop_last=True, + persistent_workers=persistent_workers) step_list = [] loss_list = [] @@ -109,18 +110,23 @@ def run_main(self, num_workers, places): def test_main(self): # dynamic graph do not run with_data_parallel for p in prepare_places(False): - results = [] - for num_workers in [0, 2]: - print(self.__class__.__name__, p, num_workers) - sys.stdout.flush() - ret = self.run_main(num_workers=num_workers, places=p) - results.append(ret) - assert results[0]['loss'].shape[0] * 2 == results[1]['loss'].shape[ - 0] + for persistent_workers in [False, True]: + results = [] + for num_workers in [0, 2]: + print(self.__class__.__name__, p, num_workers, + persistent_workers) + sys.stdout.flush() + ret = self.run_main( + num_workers=num_workers, + places=p, + persistent_workers=persistent_workers) + results.append(ret) + assert results[0]['loss'].shape[0] * 2 == results[1][ + 'loss'].shape[0] class TestDygraphDataLoaderWithBatchedDataset(TestDygraphDataLoader): - def run_main(self, num_workers, places): + def run_main(self, num_workers, places, persistent_workers): fluid.default_startup_program().random_seed = 1 fluid.default_main_program().random_seed = 1 with fluid.dygraph.guard(places[0]): @@ -132,7 +138,8 @@ def run_main(self, num_workers, places): dataset, num_workers=num_workers, batch_size=None, - drop_last=True) + drop_last=True, + persistent_workers=persistent_workers) step_list = [] loss_list = [] diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py index fe66f1733546b..9e09c5e3a1d44 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_static.py @@ -93,14 +93,14 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True): if with_gpu and fluid.core.is_compiled_with_cuda(): tmp = fluid.cuda_places()[:2] assert len(tmp) > 0, "no gpu detected" - if with_data_parallel: + if with_data_parallel and len(tmp) > 1: places.append(tmp) places.append([tmp[0]]) return places class TestStaticDataLoader(unittest.TestCase): - def run_main(self, num_workers, places): + def run_main(self, num_workers, places, persistent_workers): scope = fluid.Scope() with fluid.scope_guard(scope): startup_prog, main_prog, image, label, loss = simple_fc_net_static() @@ -113,7 +113,8 @@ def run_main(self, num_workers, places): num_workers=num_workers, batch_size=BATCH_SIZE, return_list=False, - drop_last=True) + drop_last=True, + persistent_workers=persistent_workers) # assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) exe = fluid.Executor(place=places[0]) @@ -158,14 +159,19 @@ def run_main(self, num_workers, places): def test_main(self): for p in prepare_places(True): - results = [] - for num_workers in [0, 2]: - print(self.__class__.__name__, p, num_workers) - sys.stdout.flush() - ret = self.run_main(num_workers=num_workers, places=p) - results.append(ret) - assert results[0]['loss'].shape[0] * 2 == results[1]['loss'].shape[ - 0] + for persistent_workers in [False, True]: + results = [] + for num_workers in [0, 2]: + print(self.__class__.__name__, p, num_workers, + persistent_workers) + sys.stdout.flush() + ret = self.run_main( + num_workers=num_workers, + places=p, + persistent_workers=persistent_workers) + results.append(ret) + assert results[0]['loss'].shape[0] * 2 == results[1][ + 'loss'].shape[0] class RandomBatchedDataset(IterableDataset): @@ -188,7 +194,7 @@ def __iter__(self): class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader): - def run_main(self, num_workers, places): + def run_main(self, num_workers, places, persistent_workers): scope = fluid.Scope() with fluid.scope_guard(scope): startup_prog, main_prog, image, label, loss = simple_fc_net_static() @@ -201,7 +207,8 @@ def run_main(self, num_workers, places): num_workers=num_workers, batch_size=None, return_list=False, - drop_last=True) + drop_last=True, + persistent_workers=persistent_workers) exe = fluid.Executor(place=places[0]) exe.run(startup_prog) diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py index f5bccf7ab09b6..9f73ee041e0e2 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_static.py @@ -94,14 +94,14 @@ def prepare_places(with_data_parallel, with_cpu=False, with_gpu=True): if with_gpu and fluid.core.is_compiled_with_cuda(): tmp = fluid.cuda_places()[:2] assert len(tmp) > 0, "no gpu detected" - if with_data_parallel: + if with_data_parallel and len(tmp) > 1: places.append(tmp) places.append([tmp[0]]) return places class TestStaticDataLoader(unittest.TestCase): - def run_main(self, num_workers, places, use_pe=True): + def run_main(self, num_workers, places, persistent_workers, use_pe=True): scope = fluid.Scope() with fluid.scope_guard(scope): startup_prog, main_prog, image, label, loss = simple_fc_net_static() @@ -114,7 +114,8 @@ def run_main(self, num_workers, places, use_pe=True): num_workers=num_workers, batch_size=BATCH_SIZE, return_list=False, - drop_last=True) + drop_last=True, + persistent_workers=persistent_workers) assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) exe = fluid.Executor(place=places[0]) @@ -162,16 +163,21 @@ def run_main(self, num_workers, places, use_pe=True): def test_main(self): for p in prepare_places(True): - results = [] - for num_workers in [0, 2]: - print(self.__class__.__name__, p, num_workers) - sys.stdout.flush() - ret = self.run_main(num_workers=num_workers, places=p) - results.append(ret) - diff = np.max( - np.abs(results[0]['loss'] - results[1]['loss']) / - np.abs(results[0]['loss'])) - self.assertLess(diff, 1e-2) + for persistent_workers in [True, False]: + results = [] + for num_workers in [0, 2]: + print(self.__class__.__name__, p, num_workers, + persistent_workers) + sys.stdout.flush() + ret = self.run_main( + num_workers=num_workers, + places=p, + persistent_workers=persistent_workers) + results.append(ret) + diff = np.max( + np.abs(results[0]['loss'] - results[1]['loss']) / + np.abs(results[0]['loss'])) + self.assertLess(diff, 1e-2) class TestStaticDataLoaderReturnList(unittest.TestCase): @@ -241,7 +247,7 @@ def __len__(self): class TestStaticDataLoaderWithBatchedDataset(TestStaticDataLoader): - def run_main(self, num_workers, places): + def run_main(self, num_workers, places, persistent_workers): scope = fluid.Scope() with fluid.scope_guard(scope): startup_prog, main_prog, image, label, loss = simple_fc_net_static() @@ -254,7 +260,8 @@ def run_main(self, num_workers, places): num_workers=num_workers, batch_size=None, return_list=False, - drop_last=True) + drop_last=True, + persistent_workers=persistent_workers) assert len(dataloader) == int(SAMPLE_NUM / BATCH_SIZE) exe = fluid.Executor(place=places[0])