[Auto Parallel] Improve the fine-grained APIs (#46552)
* [Auto Parallel] Support different dataloaders

* [Auto Parallel] Add num_shards config for dataset

* [Auto Parallel] Unify the logger and outputs of Engine API

* [Auto Parallel] Fix the bugs of to_static

* [Auto Parallel] Adjust the test_to_static.py

* [Auto Parallel] Add the prepare API and replace __call__ with run (see the sketch after this list)

* [Auto Parallel] Improve the private implementations of Engine

* [Auto Parallel] Set capacity of dataloader for opt tuning

* [Auto Parallel] [WIP] Change the fine-grained API

* [Auto Parallel] Improve APIs to support different user cases

* [Auto Parallel] Add removed config

* [Auto Parallel] Add imports

* [Auto Parallel] Fix bugs for to_static

* [Auto Parallel] Remove unnecessary imports
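
A minimal sketch of the fine-grained workflow these messages describe, assuming Engine and Strategy live in paddle.distributed.auto_parallel as the changed files suggest; the prepare/run argument names are not visible on this page and are assumptions rather than confirmed signatures:

# Hedged sketch only: prepare() builds the distributed program and run()
# replaces __call__(), per the commit messages above; argument names assumed.
import numpy as np
import paddle
from paddle.distributed.auto_parallel.engine import Engine
from paddle.distributed.auto_parallel.strategy import Strategy

model = paddle.nn.Linear(8, 2)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                 parameters=model.parameters())

engine = Engine(model, loss, optimizer, strategy=Strategy())
inputs_spec = paddle.static.InputSpec([None, 8], "float32", "x")
labels_spec = paddle.static.InputSpec([None, 1], "int64", "y")
engine.prepare(inputs_spec, labels_spec, mode="train")   # assumed signature

batch = {"x": np.random.rand(4, 8).astype("float32"),
         "y": np.random.randint(0, 2, (4, 1)).astype("int64")}
outs = engine.run(batch, mode="train")                    # replaces engine(batch)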
aoyulong committed Oct 12, 2022
1 parent 01baa0b commit 686fa07
Showing 9 changed files with 969 additions and 369 deletions.
7 changes: 7 additions & 0 deletions python/paddle/distributed/auto_parallel/constants.py
@@ -116,3 +116,10 @@ def set_field_default_config(category, field, default_value):
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)

#########################################
# dataset configuration
#########################################
DATASET = "dataset"
set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1)
183 changes: 133 additions & 50 deletions python/paddle/distributed/auto_parallel/dist_loader.py
@@ -17,38 +17,11 @@

import paddle
from paddle.io import BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn


class DistributedDataLoader(metaclass=abc.ABCMeta):

def __init__(self, dataset, batch_size=1, epochs=1, drop_last=False):
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
else:
self.dataset_kind = _DatasetKind.MAP

self.dataset = dataset
self.epochs = epochs
self.drop_last = drop_last

if batch_size is None:
self.batch_size = None
self.batch_sampler = None
else:
self.batch_size = batch_size
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size)
else:
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)

self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)
class DistributedDataLoaderBase(metaclass=abc.ABCMeta):

@abc.abstractmethod
def __iter__(self):
@@ -58,48 +31,70 @@ def __iter__(self):
def __next__(self):
raise NotImplementedError

@property
def index_sampler(self):
if self.auto_collate_batch:
return self.batch_sampler
else:
if self.dataset_kind == _DatasetKind.MAP:
return list(range(len(self.dataset)))
else:
return _InfiniteIterableSampler(self.dataset, 1)


class NonIterableGeneratorLoader(DistributedDataLoader):
class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):

def __init__(self,
dataset,
feed_list,
places,
feed_list=None,
capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False,
use_multiprocess=False,
drop_last=True,
places=None,
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[],
drop_last=False,
split_data=True):
data_parallel_rank=[]):
self.dataset = dataset
self.feed_list = feed_list
self.capacity = capacity
self.use_double_buffer = use_double_buffer
self.iterable = iterable
self.return_list = return_list
self.use_multiprocess = use_multiprocess
self.drop_last = drop_last
self.places = places
self.batch_size = batch_size
self.epochs = epochs
self.steps_per_epoch = steps_per_epoch

self.collate_fn = collate_fn
self.split_data = split_data
assert len(data_parallel_world_size) == len(feed_list)
assert len(data_parallel_rank) == len(feed_list)
self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank
self.split_data = split_data

super(NonIterableGeneratorLoader,
self).__init__(dataset, batch_size, epochs, drop_last)
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
else:
self.dataset_kind = _DatasetKind.MAP

if self.batch_size is None:
self.batch_sampler = None
else:
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size)
else:
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)

self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)

if self.auto_collate_batch:
self.collate_fn = collate_fn or default_collate_fn
else:
self.collate_fn = collate_fn or default_convert_fn

self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.auto_collate_batch,
self.collate_fn, self.drop_last)
@@ -115,8 +110,10 @@ def __iter__(self):
def __next__(self):
if not self._steps:
self._cur_step += 1
return None
elif self._cur_step < self._steps:
self._cur_step += 1
return None
else:
self._inner_dataloader.reset()
self.sampler_iter = iter(self.index_sampler)
@@ -138,6 +135,16 @@ def _infer_steps(self):
)
return steps_per_epoch

@property
def index_sampler(self):
if self.auto_collate_batch:
return self.batch_sampler
else:
if self.dataset_kind == _DatasetKind.MAP:
return list(range(len(self.dataset)))
else:
return _InfiniteIterableSampler(self.dataset, 1)

def _create_inner_dataloader(self):

def data_generator():
@@ -170,7 +177,83 @@ def data_generator():
yield partial_data

dataloader = paddle.fluid.io.DataLoader.from_generator(
feed_list=self.feed_list, capacity=70, iterable=False)
feed_list=self.feed_list,
capacity=self.capacity,
use_double_buffer=self.use_double_buffer,
# iterable=self.iterable,
iterable=False,
return_list=self.return_list,
use_multiprocess=self.use_multiprocess,
drop_last=self.drop_last)
dataloader.set_batch_generator(data_generator, self.places)

return dataloader


class DistributedDataLoader(DistributedDataLoaderBase):

def __init__(self,
dataset,
feed_list=None,
places=None,
return_list=True,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None,
epochs=1,
steps_per_epoch=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[]):
self.dataset = dataset
self.feed_list = feed_list
self.return_list = return_list
self.places = places
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.collate_fn = collate_fn
self.num_workers = num_workers
self.use_buffer_reader = use_buffer_reader
self.use_shared_memory = use_shared_memory
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.epochs = epochs
self.steps_per_epoch = steps_per_epoch
self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank
self.split_data = split_data
# TODO: rank info
self.batch_sampler = DistributedBatchSampler(
self.dataset, self.batch_size, self.dp_world_sizes[0],
self.dp_ranks[0], self.shuffle, self.drop_last)
self._inner_dataloader = self._create_inner_dataloader()

def __iter__(self):
return self

def __next__(self):
return next(self.data)

def _create_inner_dataloader(self):
dataloader = paddle.fluid.io.DataLoader(
self.dataset,
feed_list=self.feed_list,
places=self.places,
return_list=self.return_list,
batch_sampler=self.batch_sampler,
collate_fn=self.collate_fn,
num_workers=self.num_workers,
use_buffer_reader=self.use_buffer_reader,
use_shared_memory=self.use_shared_memory,
timeout=self.timeout,
worker_init_fn=self.worker_init_fn)
self.data = (x for x in dataloader)

return dataloader
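
For reference, a minimal usage sketch of the map-style wrapper introduced above. In practice the Engine constructs these loaders internally, so building one by hand is shown only to illustrate the new data_parallel_world_size / data_parallel_rank arguments:

# Hedged sketch: a single-process run uses a world size of [1] and rank [0];
# the Engine normally fills these in from the parallel topology.
import numpy as np
import paddle
from paddle.distributed.auto_parallel.dist_loader import DistributedDataLoader

class RandomDataset(paddle.io.Dataset):
    def __init__(self, n=16):
        self.x = np.random.rand(n, 8).astype("float32")
        self.y = np.random.randint(0, 2, (n, 1)).astype("int64")
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
    def __len__(self):
        return len(self.x)

loader = DistributedDataLoader(RandomDataset(),
                               batch_size=4,
                               epochs=1,
                               data_parallel_world_size=[1],
                               data_parallel_rank=[0])
for x, y in loader:
    print(x.shape, y.shape)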