diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 004b3fb0f666b..5b8d185212c23 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -156,6 +156,10 @@ def __init__(self, topology): self.is_first_stage = (self.stage_id == 0) self.is_last_stage = (self.stage_id == (self._pp_degree - 1)) + # create p2p_groups + if self._pp_degree > 1: + self._set_p2p_group() + debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \ "sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree, self._sharding_degree, self._pp_degree, self._dp_degree) @@ -164,27 +168,9 @@ def __init__(self, topology): self._dp_group, self._check_group) logger.info(debug_str) - # create p2p_groups and no new group - self._p2p_groups = self._build_p2p_lists() - global _HYBRID_PARALLEL_GROUP _HYBRID_PARALLEL_GROUP = self - def _build_p2p_lists(self): - comm_lists = self._topo.get_comm_list('pipe') - p2p_lists = [] - for rank in range(self.nranks): - for comm_ranks in comm_lists: - assert len(comm_ranks) == self._pp_degree - if rank in comm_ranks: - idx = comm_ranks.index(rank) - next_rank = comm_ranks[(idx + 1) % self._pp_degree] - p2p_lists.append([rank, next_rank]) - break - assert len( - p2p_lists) == self.nranks, "len(p2p_lists) should be equal nranks" - return p2p_lists - def get_parallel_mode(self): # there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel # NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and @@ -236,6 +222,41 @@ def _set_check_group(self, parallel_method="data"): return parallel_group, parallel_comm_group + def _set_p2p_group(self): + comm_lists = self._topo.get_comm_list('pipe') + + self.send_next_group = None + self.send_prev_group = None + self.recv_next_group = None + self.recv_prev_group = None + + for comm_ranks in comm_lists: + assert len(comm_ranks) == self._pp_degree + for idx, rank in enumerate(comm_ranks): + curr_rank = rank + next_rank = comm_ranks[(idx + 1) % self._pp_degree] + prev_rank = comm_ranks[(idx - 1) % self._pp_degree] + + next_group = paddle.distributed.new_group( + ranks=[curr_rank, next_rank]) + if self.global_rank == curr_rank: + self.send_next_group = next_group + elif self.global_rank == next_rank: + self.recv_prev_group = next_group + + prev_group = paddle.distributed.new_group( + ranks=[prev_rank, curr_rank]) + + if self.global_rank == curr_rank: + self.send_prev_group = prev_group + elif self.global_rank == prev_rank: + self.recv_next_group = prev_group + + assert self.send_next_group is not None + assert self.send_prev_group is not None + assert self.recv_next_group is not None + assert self.recv_prev_group is not None + def topology(self): return self._topo @@ -287,6 +308,9 @@ def get_pipe_parallel_world_size(self): def get_pipe_parallel_group(self): return self._pp_comm_group + def get_p2p_groups(self): + return self.send_next_group, self.send_prev_group, self.recv_next_group, self.recv_prev_group + # sharding parallel message: def _get_sharding_parallel_id(self): return self._topo.get_coord(self.global_rank).sharding @@ -304,9 +328,6 @@ def get_sharding_parallel_group_src_rank(self): # TODO should the src rank related to the shard rank for each parameter ? 
return self._sharding_comm_group.ranks[0] - def get_p2p_groups(self): - return self._p2p_groups - # check parallel group def get_check_parallel_group(self): return self._check_comm_group diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 9f2a4aaffb474..1cec106caec82 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -11,19 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -import numpy as np - import paddle import paddle.fluid as fluid from .meta_parallel_base import MetaParallelBase -from .pp_utils.utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype -from .pp_utils import utils +from .pp_utils.utils import is_float_tensor from .parallel_layers.pp_layers import PipelineLayer from ..utils.hybrid_parallel_util import broadcast_mp_parameters from ..utils.hybrid_parallel_util import broadcast_dp_parameters from ..utils.log_util import logger -from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer +from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler from .pp_utils import p2p_communication as p2p __all__ = [] @@ -35,25 +32,9 @@ def __init__(self, layers, hcg, strategy): raise TypeError( "The Layer should be a derived class of PipelineLayer.") super(PipelineParallel, self).__init__(layers, hcg, strategy) - self.use_pipe_parallel = self._hcg.get_pipe_parallel_world_size() > 1 self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1 self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1 - self.is_pipe_partitioned = self.use_model_parallel - - self.num_caches = 0 - self.caches = { - 'inputs': [], - 'labels': [], - 'outputs': [], - } - - self.recv_cache = None - self.grad_tensors = None - - self.send_meta = True - - self.current_loss = paddle.to_tensor(0.0) self.total_loss = None self.micro_batch_size = self._strategy.pipeline_configs[ @@ -63,17 +44,14 @@ def __init__(self, layers, hcg, strategy): self.num_stages = self._hcg.get_pipe_parallel_world_size() self.stage_id = self._hcg.get_stage_id() - self.prev_stage_id = self.stage_id - 1 - self.next_stage_id = self.stage_id + 1 self.pp_group = self._hcg.get_pipe_parallel_group() + p2p.initialize_p2p_groups(hcg) self.is_first_stage = self.stage_id == 0 self.is_last_stage = (self.stage_id == (self.num_stages - 1)) self.global_rank = self._hcg.get_global_rank() - - self.mp_degree = self._hcg.get_model_parallel_world_size() - self.mp_rank = self._hcg.get_model_parallel_rank() + self.micro_batch_id = 0 logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format( self.num_stages, self.stage_id)) @@ -86,158 +64,160 @@ def __init__(self, layers, hcg, strategy): logger.info("start broadcast dp parameters") broadcast_dp_parameters(self._layers, self._hcg) - def _init_caches(self, num_caches): - if self.num_caches >= num_caches: + def _set_tensor_trainable(self, tensor): + if tensor is None: return - self.num_caches = num_caches - self.num_caches - for key in self.caches: - self.caches[key].extend([None] * self.num_caches) - def _reduce_final_loss(self): - if self.is_last_stage: - assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss" - loss = self.total_loss.clone() / self.accumulate_steps - 
paddle.distributed.broadcast( - loss, - src=self.global_rank, - use_calc_stream=True, - group=self.pp_group) + if isinstance(tensor, tuple): + for t in tensor: + if is_float_tensor(t): + t.stop_gradient = False else: - loss = paddle.to_tensor(0.0) - paddle.distributed.broadcast( - loss, - src=self._hcg.get_rank_from_stage(self.num_stages - 1), - use_calc_stream=True, - group=self.pp_group) - return loss + if is_float_tensor(tensor): + tensor.stop_gradient = False def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None): assert isinstance(optimizer, HybridParallelOptimizer), ( 'optimizer should be HybridParallelOptimizer subclass.') - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - self.scaler = scaler + if scaler is not None: + assert isinstance(scaler, HybridParallelGradScaler), ( + 'scaler should be HybridParallelGradScaler subclass or None.') assert fluid.framework._dygraph_tracer()._has_grad, ( 'Please enable the generation of gradients.') if self.is_first_stage or self.is_last_stage: assert data is not None, ( - "For the first and the last stage, the data_iter must be set.") + "For the first and the last stage, the data must be set.") else: data = None + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.scaler = scaler self.data = data + self._layers.train() # store total loss of entire batch self.total_loss = None - self._init_caches(self.accumulate_steps) - startup_steps = self.num_stages - self.stage_id - 1 - forward_steps = 0 - backward_steps = 0 - # forward - while (forward_steps < self.accumulate_steps): - self._forward(cache_id=forward_steps) - forward_steps += 1 + # store data id for micro_batch + self.micro_batch_id = 0 - # backward - while (backward_steps < self.accumulate_steps): - self._backward(cache_id=backward_steps) - backward_steps += 1 + # Next, use the 1f1b scheduling strategy. 
+ # this strategy is inspired by: + # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py - self._layers.allreduce_shared_weight_gradients() + startup_steps = (self.num_stages - self.stage_id - 1) + startup_steps = min(startup_steps, self.accumulate_steps) + steady_steps = self.accumulate_steps - startup_steps - # optimizer - self.train_loss = self._reduce_final_loss() - self._step() - return self.train_loss + input_buffers = [] + output_buffers = [] - def _forward(self, cache_id): - # load data - self._load_micro_batch(cache_id) - if self.stage_id != 0: - self._recv_activations(cache_id) + for step_id in range(startup_steps): + input_tensor = p2p.recv_forward() + self._set_tensor_trainable(input_tensor) - if isinstance(self.caches['inputs'][cache_id], tuple): - inputs = tuple(t for t in self.caches['inputs'][cache_id]) - else: - inputs = self.caches['inputs'][cache_id] + output_tensor = self._forward_step(input_tensor) + p2p.send_forward(output_tensor) - self._clear_grads(inputs) - outputs = self._layers.forward(inputs) + input_buffers.append(input_tensor) + output_buffers.append(output_tensor) - self.caches['outputs'][cache_id] = outputs + if steady_steps > 0: + input_tensor = p2p.recv_forward() - if self.is_last_stage: - if self._layers._loss_fn is not None: - labels = self.caches['labels'][cache_id] - outputs = self._layers._loss_fn(outputs, labels) + for i in range(steady_steps): + last_iter = (i == (steady_steps - 1)) - if self.is_last_stage: - self.current_loss = outputs - if isinstance(self.current_loss, paddle.Tensor): - if self.total_loss is None: - self.total_loss = paddle.zeros_like(self.current_loss) - self.total_loss += self.current_loss.detach() - else: - if self.total_loss is None: - self.total_loss = [ - paddle.zeros_like(v) for v in self.current_loss - ] - for idx, v in enumerate(self.current_loss): - self.total_loss[idx] += v.detach() + self._set_tensor_trainable(input_tensor) + output_tensor = self._forward_step(input_tensor) - if self.accumulate_steps > 1: - self.current_loss = self.current_loss / self.accumulate_steps + output_tensor_grad = p2p.send_forward_recv_backward(output_tensor) - self.caches['outputs'][cache_id] = self.current_loss.clone() + input_buffers.append(input_tensor) + output_buffers.append(output_tensor) - else: - self._send_activations(cache_id) + input_tensor, output_tensor = input_buffers.pop( + 0), output_buffers.pop(0) - def _backward(self, cache_id): - if self.is_last_stage: - if self.scaler: - paddle.autograd.backward( - self.scaler.scale(self.caches['outputs'][cache_id])) + input_tensor_grad = self._backward_step(input_tensor, output_tensor, + output_tensor_grad) + + if last_iter: + input_tensor = None + p2p.send_backward(input_tensor_grad) else: - paddle.autograd.backward(self.caches['outputs'][cache_id]) + input_tensor = p2p.send_backward_recv_forward(input_tensor_grad) - self._send_gradients(cache_id) - return - self._recv_gradients(cache_id) + for i in range(startup_steps): + input_tensor = input_buffers.pop(0) + output_tensor = output_buffers.pop(0) - outputs = self.caches['outputs'][cache_id] + output_tensor_grad = p2p.recv_backward() - grad_tensors = self.grad_tensors - if isinstance(outputs, tuple): - out_tensors = [t for t in outputs if is_float_tensor(t)] - assert len(out_tensors) == len(grad_tensors) - paddle.autograd.backward( - tensors=out_tensors, grad_tensors=grad_tensors) - else: - paddle.autograd.backward( - tensors=[outputs], grad_tensors=[grad_tensors]) + input_tensor_grad = 
self._backward_step(input_tensor, output_tensor, + output_tensor_grad) + p2p.send_backward(input_tensor_grad) - grad_tensors = None - if self.stage_id != 0: self._send_gradients(cache_id) - self.caches['outputs'][cache_id] = None + self._layers.allreduce_shared_weight_gradients() - def _broadcast_data(self, data): - if isinstance(data, paddle.Tensor): - paddle.distributed.broadcast( - data, - src=self._hcg.get_model_parallel_group_src_rank(), - group=self._hcg.get_model_parallel_group()) + self.train_loss = self._reduce_final_loss() + + # optimizer + self._optimizer_step() + return self.train_loss + + def _forward_step(self, input_tensor): + if self.stage_id == 0: + input_tensor = self._load_micro_batch(self.micro_batch_id) + + output_tensor = self._layers.forward(input_tensor) + + if self.is_last_stage: + labels = self._load_micro_batch(self.micro_batch_id) + output_tensor = self._layers._loss_fn(output_tensor, labels) + assert isinstance( + output_tensor, paddle. + Tensor), "Currently, loss_fn should obtain Paddle.Tensor dtype" + + if self.accumulate_steps > 1: + output_tensor = output_tensor / self.accumulate_steps + + if self.total_loss is None: + self.total_loss = paddle.zeros_like(output_tensor) + self.total_loss += output_tensor.detach() + + self.micro_batch_id += 1 + return output_tensor + + def _backward_step(self, input_tensor, output_tensor, output_tensor_grad): + if self.is_last_stage: + assert output_tensor_grad is None + if self.scaler: + paddle.autograd.backward(self.scaler.scale(output_tensor)) + else: + paddle.autograd.backward(output_tensor) else: - for d in data: - assert isinstance(d, paddle.Tensor) - paddle.distributed.broadcast( - d, - src=self._hcg.get_model_parallel_group_src_rank(), - group=self._hcg.get_model_parallel_group()) - return data + if isinstance(output_tensor, tuple): + outputs = [t for t in output_tensor if not t.stop_gradient] + assert len(outputs) == len(output_tensor_grad) + paddle.autograd.backward( + tensors=outputs, + grad_tensors=[t for t in output_tensor_grad]) + else: + paddle.autograd.backward( + tensors=[output_tensor], grad_tensors=[output_tensor_grad]) + + input_tensor_grad = None + if input_tensor is not None: + if isinstance(input_tensor, tuple): + input_tensor_grad = tuple( + [t.grad for t in input_tensor if not t.stop_gradient]) + else: + input_tensor_grad = input_tensor.grad + return input_tensor_grad def _load_micro_batch(self, cache_id): inputs = self.data @@ -246,8 +226,6 @@ def _load_micro_batch(self, cache_id): if self.is_first_stage: assert len(inputs) == 2, "length of input should be 2" - if self.use_model_parallel: - inputs[0] = self._broadcast_data(inputs[0]) if isinstance(inputs[0], tuple): batch_size = inputs[0][0].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size, ( @@ -255,332 +233,51 @@ def _load_micro_batch(self, cache_id): "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d." 
% (batch_size, self.micro_batch_size, self.accumulate_steps)) - data = [ - input[begin:end, :].clone().detach() for input in inputs[0] - ] - self.caches['inputs'][cache_id] = tuple(data) + data = [input[begin:end, :].detach() for input in inputs[0]] + return tuple(data) else: batch_size = inputs[0].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size - self.caches['inputs'][cache_id] = inputs[0][begin:end, :].clone( - ).detach() + return inputs[0][begin:end, :].detach() elif self.is_last_stage: assert len(inputs) == 2, "length of input should be 2" - if self.use_model_parallel: - inputs[1] = self._broadcast_data(inputs[1]) if isinstance(inputs[1], tuple): batch_size = inputs[1][0].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size - data = [ - input[begin:end, :].clone().detach() for input in inputs[1] - ] - self.caches['labels'][cache_id] = tuple(data) + data = [input[begin:end, :].detach() for input in inputs[1]] + return tuple(data) else: batch_size = inputs[1].shape[0] assert self.micro_batch_size * self.accumulate_steps == batch_size - self.caches['labels'][cache_id] = inputs[1][begin:end, :].clone( - ).detach() + return inputs[1][begin:end, :].detach() else: # No data input is required for other stages inputs = None - def _send_meta(self, data, peer): - if isinstance(data, paddle.Tensor): - tensor_type = paddle.to_tensor([0]) - # send tensor type - p2p.send(tensor_type, self.next_stage_id) - - # send len(shape) - dims = paddle.to_tensor(len(data.shape)) - p2p.send(dims, self.next_stage_id) - - # send shape - shape = paddle.to_tensor(data.shape) - p2p.send(shape, self.next_stage_id) - - # send dtype - dtype = paddle.to_tensor(paddle_2_number(data.dtype)) - p2p.send(dtype, self.next_stage_id) - - elif isinstance(data, tuple): - tensor_type = paddle.to_tensor([1]) - p2p.send(tensor_type, self.next_stage_id) - - nums = paddle.to_tensor(len(data)) - p2p.send(nums, self.next_stage_id) - - for idx, d in enumerate(data): - assert isinstance(d, paddle.Tensor) - # send len(shape) - dims = paddle.to_tensor(len(d.shape)) - p2p.send(dims, self.next_stage_id) - - # send shape - shape = paddle.to_tensor(d.shape) - p2p.send(shape, self.next_stage_id) - - # send dtype - dtype = paddle.to_tensor(paddle_2_number(d.dtype)) - p2p.send(dtype, self.next_stage_id) - - def _recv_meta(self, peer): - tensor_type = paddle.to_tensor([0]) - p2p.recv(tensor_type, self.prev_stage_id) - - tensor_type = tensor_type.item() - - if tensor_type == 0: - # recv len(shape) - dims = paddle.to_tensor([0]) - p2p.recv(dims, self.prev_stage_id) - - dims = dims.item() - - # recv shape - shape = paddle.to_tensor([0] * dims) - p2p.recv(shape, self.prev_stage_id) - - shape = shape.numpy().tolist() - - # recv dtype - dtype = paddle.to_tensor([0]) - p2p.recv(dtype, self.prev_stage_id) - - return self._allocate_cache( - shape, dtype=number_2_dtype(dtype.item()), num_caches=1)[0] - elif tensor_type == 1: - num = paddle.to_tensor([0]) - p2p.recv(num, self.prev_stage_id) - num = num.item() - shapes = [] - dtypes = [] - for i in range(num): - # recv len(shape) - dims = paddle.to_tensor([0]) - p2p.recv(dims, self.prev_stage_id) - - # recv shape - dims = dims.item() - shape = paddle.to_tensor([0] * dims) - p2p.recv(shape, self.prev_stage_id) - shapes.append(shape.numpy().tolist()) - - # recv dtype - dtype = paddle.to_tensor([0]) - p2p.recv(dtype, self.prev_stage_id) - dtypes.append(number_2_dtype(dtype.item())) - - caches = self._allocate_caches(shapes, dtypes, num_caches=1)[0] - caches = 
tuple(caches) - return caches - - def _is_valid_send_recv(self, tensor): - tensor_numel = np.prod(tensor.shape) - assert tensor_numel != 0, "can't send/recv zero element" - return tensor_numel % self.mp_degree == 0 - - def _send_activations(self, cache_id): - outputs = self.caches['outputs'][cache_id] - - if self.send_meta: - self.send_meta = False - self._send_meta(outputs, self.next_stage_id) - - if isinstance(outputs, paddle.Tensor): - if self.is_pipe_partitioned and self._is_valid_send_recv(outputs): - p2p.send_partial( - outputs.detach(), - self.next_stage_id, - mp_degree=self.mp_degree, - mp_rank=self.mp_rank) - else: - p2p.send(outputs.detach(), self.next_stage_id) - - elif isinstance(outputs, tuple): - for output in outputs: - if self.is_pipe_partitioned and self._is_valid_send_recv( - output): - p2p.send_partial( - output.detach(), - self.next_stage_id, - mp_degree=self.mp_degree, - mp_rank=self.mp_rank) - else: - p2p.send(output.detach(), self.next_stage_id) - - def _send_gradients(self, cache_id): - inputs = self.caches['inputs'][cache_id] - if isinstance(inputs, paddle.Tensor): - assert inputs.grad is not None - if self.is_pipe_partitioned and self._is_valid_send_recv( - inputs.grad): - grad = p2p.send_partial( - inputs.grad, - self.prev_stage_id, - mp_degree=self.mp_degree, - mp_rank=self.mp_rank) - else: - p2p.send(inputs.grad, self.prev_stage_id) - else: - for idx, d in enumerate(inputs): - # Skip tensors that will not produce a grad - if not is_float_tensor(d): - assert d.grad is None - continue - - if self.is_pipe_partitioned and self._is_valid_send_recv( - d.grad): - grad = p2p.send_partial( - d.grad, - self.prev_stage_id, - mp_degree=self.mp_degree, - mp_rank=self.mp_rank) - else: - p2p.send(d.grad, self.prev_stage_id) - - self.caches['inputs'][cache_id] = None - - def _recv_activations(self, cache_id): - inputs = None - if self.recv_cache is None: - self.recv_cache = self._recv_meta(self.prev_stage_id) - - if isinstance(self.recv_cache, paddle.Tensor): - if self.is_pipe_partitioned and self._is_valid_send_recv( - self.recv_cache): - p2p.recv_partial(self.recv_cache, self.prev_stage_id, - self.mp_degree, self.mp_rank) - p2p.partial_allgather_operator( - self.recv_cache, - mp_ranks=self.mp_degree, - mp_rank_id=self.mp_rank, - group=self._hcg.get_model_parallel_group(), - use_calc_stream=True) - else: - p2p.recv(self.recv_cache, self.prev_stage_id) - - inputs = self.recv_cache.clone().detach() - inputs.stop_gradient = not is_float_tensor(inputs) - + def _reduce_final_loss(self): + if self.is_last_stage: + assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss" + loss = self.total_loss.detach() + paddle.distributed.broadcast( + loss, + src=self.global_rank, + use_calc_stream=True, + group=self.pp_group) else: - assert isinstance(self.recv_cache, tuple) - inputs = [None] * len(self.recv_cache) - for idx, d in enumerate(self.recv_cache): - if self.is_pipe_partitioned and self._is_valid_send_recv(d): - assert isinstance(d, paddle.Tensor) - p2p.recv_partial(d, self.prev_stage_id, self.mp_degree, - self.mp_rank) - p2p.partial_allgather_operator( - d, - mp_ranks=self.mp_degree, - mp_rank_id=self.mp_rank, - group=self._hcg.get_model_parallel_group(), - use_calc_stream=True) - else: - assert isinstance(d, paddle.Tensor) - p2p.recv(d, self.prev_stage_id) - inputs[idx] = d.clone().detach() - - inputs = tuple(inputs) - - for d in inputs: - d.stop_gradient = not is_float_tensor(d) - - self.caches['inputs'][cache_id] = inputs - - def 
_recv_gradients(self, cache_id): - outputs = self.caches['outputs'][cache_id] - if self.grad_tensors is None: - if isinstance(outputs, paddle.Tensor): - s = list(outputs.shape) - dtype = get_tensor_dtype(outputs.dtype) - self.grad_tensors = self._allocate_cache( - s, dtype, num_caches=1)[0] - else: - sizes = [list(d.shape) for d in outputs if is_float_tensor(d)] - dtypes = [ - get_tensor_dtype(d.dtype) for d in outputs - if is_float_tensor(d) - ] - self.grad_tensors = self._allocate_caches( - sizes, dtypes, num_caches=1)[0] - - if isinstance(self.grad_tensors, paddle.Tensor): - if self.is_pipe_partitioned and self._is_valid_send_recv( - self.grad_tensors): - p2p.recv_partial(self.grad_tensors, self.next_stage_id, - self.mp_degree, self.mp_rank) - p2p.partial_allgather_operator( - self.grad_tensors, - mp_ranks=self.mp_degree, - mp_rank_id=self.mp_rank, - group=self._hcg.get_model_parallel_group(), - use_calc_stream=True) - else: - p2p.recv(self.grad_tensors, self.next_stage_id) + loss = paddle.zeros(shape=[1], dtype="float32") + paddle.distributed.broadcast( + loss, + src=self._hcg.get_rank_from_stage(self.num_stages - 1), + use_calc_stream=True, + group=self.pp_group) + return loss - else: - assert isinstance(outputs, tuple) - for d in self.grad_tensors: - if self.is_pipe_partitioned and self._is_valid_send_recv(d): - p2p.recv_partial(d, self.next_stage_id, self.mp_degree, - self.mp_rank) - p2p.partial_allgather_operator( - d, - mp_ranks=self.mp_degree, - mp_rank_id=self.mp_rank, - group=self._hcg.get_model_parallel_group(), - use_calc_stream=True) - else: - p2p.recv(d, self.next_stage_id) - - def _step(self): + def _optimizer_step(self): if self.scaler: self.scaler.minimize(self.optimizer, self.train_loss) else: self.optimizer.step() + self.optimizer.clear_grad() if self.lr_scheduler: self.lr_scheduler.step() - - def _clear_grads(self, inputs): - if isinstance(inputs, paddle.Tensor): - if inputs.grad is not None: - inputs.clear_gradient() - else: - for d in inputs: - if d.grad is not None: - d.clear_gradient() - - def _allocate_zeros(self, shape, dtype): - return paddle.zeros(shape, dtype) - - def _allocate_cache(self, shape, dtype, num_caches=-1): - caches = [] - if num_caches == -1: - num_caches = self.num_caches - for count in range(num_caches): - caches.append(self._allocate_zeros(shape, dtype)) - return caches - - def _allocate_caches(self, shapes, dtypes, num_caches=-1): - caches = [] - if num_caches == -1: - num_caches = self.num_caches - for count in range(num_caches): - cache = [] - for shape, dtype in zip(shapes, dtypes): - cache.append(self._allocate_zeros(shape, dtype)) - caches.append(cache) - return caches - - def save_state_dict(self, model_path): - state_dict = self._layers.state_dict() - paddle.save(state_dict, model_path) - - def load_state_dict(self, model_path): - state_dict = paddle.load(self.model_path) - self._layers.set_state_dict(state_dict) - - def forward(self, *inputs, **kwargs): - raise RuntimeError("Call train_batch for pipeline instead of forward.") diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 44090be94f1a7..e533b2ef3f7a3 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -13,131 +13,388 @@ # limitations under the License. 
import paddle +from .utils import paddle_2_number, number_2_dtype +from ...utils.log_util import logger -_groups = None _hcg = None def initialize_p2p_groups(hcg): - global _groups, _hcg - _groups = [ - paddle.distributed.new_group(ranks=group) - for group in hcg.get_p2p_groups() - ] + global _hcg _hcg = hcg + send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups( + ) + debug_str = "P2pInfo: send_next_group: %s, send_prev_group: %s, " \ + "recv_next_group: %s, recv_prev_group: %s" % (repr(send_next_group), + repr(send_prev_group),repr(recv_next_group), repr(recv_prev_group)) + logger.info(debug_str) -def _is_valid_communciate(src_stage, dest_stage): - first_stage = 0 - last_stage = _hcg.get_pipe_parallel_world_size() - 1 - assert abs(src_stage-dest_stage) == 1 or \ - (src_stage == first_stage and dest_stage == last_stage) or \ - (src_stage == last_stage and dest_stage == first_stage) +class SendRecvMeta: + """Mainly used to help p2p communication context information""" -def partial_send_operator(tensor, - dst=0, - mp_ranks=1, - mp_rank_id=0, - group=None, - use_calc_stream=True): + def __init__(self): + self.send_shape_message = None + self.send_dtype_message = None + self.recv_shape_message = None + self.recv_dtype_message = None + + self.has_send_meta = False + self.has_recv_meta = False + + def _recv_shape_dtype(self, group): + # recv len(shape) + dims = paddle.to_tensor([0]) + paddle.distributed.recv(dims, src=0, group=group) + dims = dims.item() + + # recv shape + shape = paddle.to_tensor([0] * dims) + paddle.distributed.recv(shape, src=0, group=group) + + # recv dtype + dtype = paddle.to_tensor([0]) + paddle.distributed.recv(dtype, src=0, group=group) + return shape.numpy().tolist(), dtype.item() + + def recv_meta(self, group): + tensor_type = paddle.to_tensor([0]) + paddle.distributed.recv(tensor_type, src=0, group=group) + tensor_type = tensor_type.item() + + if tensor_type == 0: + shape, dtype = self._recv_shape_dtype(group) + self.recv_shape_message = shape + self.recv_dtype_message = dtype + + elif tensor_type == 1: + num = paddle.to_tensor([0]) + paddle.distributed.recv(num, src=0, group=group) + num = num.item() + shapes = [] + dtypes = [] + for i in range(num): + shape, dtype = self._recv_shape_dtype(group) + shapes.append(shape) + dtypes.append(dtype) + + self.recv_shape_message = tuple(shapes) + self.recv_dtype_message = tuple(dtypes) + + def _send_dims_shape_dtype(self, tensor, group): + # send len(shape) + dims = paddle.to_tensor(len(tensor.shape)) + paddle.distributed.send(dims, dst=1, group=group) + + # send shape + shape = paddle.to_tensor(tensor.shape) + paddle.distributed.send(shape, dst=1, group=group) + + # send dtype + dtype = paddle.to_tensor(paddle_2_number(tensor.dtype)) + paddle.distributed.send(dtype, dst=1, group=group) + + def send_meta(self, tensor, group): + if isinstance(tensor, paddle.Tensor): + tensor_type = paddle.to_tensor([0]) + # send tensor type + paddle.distributed.send(tensor_type, dst=1, group=group) + + self._send_dims_shape_dtype(tensor, group) + elif isinstance(tensor, tuple): + tensor_type = paddle.to_tensor([1]) + # send tensor type + paddle.distributed.send(tensor_type, dst=1, group=group) + + nums = paddle.to_tensor(len(tensor)) + paddle.distributed.send(nums, dst=1, group=group) + + for d in tensor: + assert isinstance(d, paddle.Tensor) + self._send_dims_shape_dtype(d, group=group) + + def set_send_message(self, tensor): + if isinstance(tensor, paddle.Tensor): + self.send_shape_message = tensor.shape + 
self.send_dtype_message = paddle_2_number(tensor.dtype) + elif isinstance(tensor, tuple): + self.send_shape_message = tuple( + [d.shape for d in tensor if not d.stop_gradient]) + self.send_dtype_message = tuple( + [paddle_2_number(d.dtype) for d in tensor]) + + +_send_recv_meta = SendRecvMeta() + + +def send_partial(tensor, + dst=0, + nranks=1, + rank_id=0, + group=None, + use_calc_stream=True): if group is not None and not group.is_member(): return ring_id = 0 if group is None else group.id return paddle.fluid.core.ops.partial_send( tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer', - dst, 'num', mp_ranks, 'id', mp_rank_id) + dst, 'num', nranks, 'id', rank_id) -def partial_recv_operator(tensor, - src=0, - mp_ranks=1, - mp_rank_id=0, - group=None, - use_calc_stream=True): - +def recv_partial(tensor, + src=0, + nranks=1, + rank_id=0, + group=None, + use_calc_stream=True): if group is not None and not group.is_member(): return ring_id = 0 if group is None else group.id - return paddle.fluid.core.ops.partial_recv( + paddle.fluid.core.ops.partial_recv( tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'peer', - src, 'num', mp_ranks, 'id', mp_rank_id, 'dtype', tensor.dtype, - 'out_shape', tensor.shape) + src, 'num', nranks, 'id', rank_id, 'dtype', tensor.dtype, 'out_shape', + tensor.shape) -def partial_allgather_operator(tensor, - mp_ranks=1, - mp_rank_id=0, - group=None, - use_calc_stream=True): +def allgather_partial(tensor, + nranks=1, + rank_id=0, + group=None, + use_calc_stream=True): + if nranks == 1: + return tensor if group is not None and not group.is_member(): return ring_id = 0 if group is None else group.id return paddle.fluid.core.ops.partial_allgather_( tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, - 'nranks', mp_ranks, 'rank', mp_rank_id) - - -def send(tensor, dest_stage): - global _groups, _hcg - src_stage = _hcg.get_stage_id() - _is_valid_communciate(src_stage, dest_stage) - group = _get_send_recv_group(src_stage, dest_stage) - return paddle.distributed.send( - tensor, dst=1 if dest_stage > src_stage else 0, group=group) - - -def recv(tensor, src_stage): - global _groups, _hcg - dest_stage = _hcg.get_stage_id() - - _is_valid_communciate(src_stage, dest_stage) - group = _get_send_recv_group(src_stage, dest_stage) - return paddle.distributed.recv( - tensor, src=0 if dest_stage > src_stage else 1, group=group) - - -def send_partial(tensor, dest_stage, mp_degree, mp_rank): - global _groups, _hcg - src_stage = _hcg.get_stage_id() - _is_valid_communciate(src_stage, dest_stage) - group = _get_send_recv_group(src_stage, dest_stage) - return partial_send_operator( - tensor, - dst=1 if dest_stage > src_stage else 0, - mp_ranks=mp_degree, - mp_rank_id=mp_rank, - group=group) - - -def recv_partial(tensor, src_stage, mp_degree, mp_rank): - global _groups, _hcg - dest_stage = _hcg.get_stage_id() - - _is_valid_communciate(src_stage, dest_stage) - group = _get_send_recv_group(src_stage, dest_stage) - return partial_recv_operator( - tensor, - src=0 if dest_stage > src_stage else 1, - mp_ranks=mp_degree, - mp_rank_id=mp_rank, - group=group) - - -def _get_send_recv_group(src_stage, dest_stage): - global _groups, _hcg - stage_id = None - first_stage = 0 - last_stage = _hcg.get_pipe_parallel_world_size() - 1 - if (src_stage == first_stage and dest_stage == last_stage) or \ - (dest_stage == first_stage and src_stage == last_stage): - stage_id = last_stage - elif src_stage > dest_stage: - stage_id = dest_stage + 'nranks', nranks, 'rank', rank_id) + 
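Reviewer note: the renamed send_partial / recv_partial / allgather_partial helpers transmit only the 1/nranks slice of a tensor owned by the current model-parallel rank and rebuild the full tensor on the receiving side with a partial allgather over the mp group, so pipeline p2p traffic is not duplicated across tensor-parallel ranks. Below is a minimal single-process sketch of that data movement using plain Python lists as stand-ins for the fluid core ops (the function names here are illustrative only, not Paddle APIs):

# Conceptual sketch only: emulates how a tensor is split into `nranks`
# equal slices for the partial p2p transfer and reassembled by an allgather.
def split_for_partial_send(flat_tensor, nranks):
    # partial_send transmits just the slice owned by `rank_id`
    assert len(flat_tensor) % nranks == 0, "numel must divide nranks evenly"
    step = len(flat_tensor) // nranks
    return [flat_tensor[i * step:(i + 1) * step] for i in range(nranks)]

def allgather_slices(slices):
    # allgather_partial restores the full tensor on every mp rank
    full = []
    for s in slices:
        full.extend(s)
    return full

if __name__ == "__main__":
    tensor = list(range(8))            # stand-in for a flattened activation
    slices = split_for_partial_send(tensor, nranks=2)
    assert allgather_slices(slices) == tensor

The same split-then-allgather shape is what _p2p_helper below applies to every tensor it sends or receives.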
+ +def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): + global _hcg + + tensor_recv_prev = None + tensor_recv_next = None + + # send / recv message + recv_shape_msg = _send_recv_meta.recv_shape_message + recv_dtype_msg = _send_recv_meta.recv_dtype_message + send_shape_msg = _send_recv_meta.send_shape_message + send_dtype_msg = _send_recv_meta.send_dtype_message + + # model parallel message + mp_group = _hcg.get_model_parallel_group() + mp_degree = _hcg.get_model_parallel_world_size() + mp_rank = _hcg.get_model_parallel_rank() + + if recv_prev: + if isinstance(recv_shape_msg, tuple): + tensor_recv_prev = [] + for idx, shape in enumerate(recv_shape_msg): + tensor_recv_prev.append( + paddle.empty( + shape=shape, dtype=number_2_dtype(recv_dtype_msg[idx]))) + tensor_recv_prev = tuple(tensor_recv_prev) + else: + tensor_recv_prev = paddle.empty( + shape=recv_shape_msg, dtype=number_2_dtype(recv_dtype_msg)) + + if recv_next: + if isinstance(send_shape_msg, tuple): + tensor_recv_next = [] + for idx, shape in enumerate(send_shape_msg): + tensor_recv_next.append( + paddle.empty( + shape=shape, dtype=number_2_dtype(send_dtype_msg[idx]))) + tensor_recv_next = tuple(tensor_recv_next) + else: + tensor_recv_next = paddle.empty( + shape=send_shape_msg, dtype=number_2_dtype(send_dtype_msg)) + + # start to p2p communicate + if tensor_send_prev is not None: + if isinstance(tensor_send_prev, tuple): + for d in tensor_send_prev: + paddle.distributed.wait(d, use_calc_stream=True) + send_partial( + d, + dst=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_prev_group, + use_calc_stream=False) + else: + paddle.distributed.wait(tensor_send_prev, use_calc_stream=True) + send_partial( + tensor_send_prev, + dst=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_prev_group, + use_calc_stream=False) + + if tensor_recv_prev is not None: + if isinstance(tensor_recv_prev, tuple): + for d in tensor_recv_prev: + recv_partial( + d, + src=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_prev_group, + use_calc_stream=True) + allgather_partial( + d, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + else: + recv_partial( + tensor_recv_prev, + src=0, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_prev_group, + use_calc_stream=True) + allgather_partial( + tensor_recv_prev, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + + if tensor_send_next is not None: + if isinstance(tensor_send_next, tuple): + for d in tensor_send_next: + paddle.distributed.wait(d, use_calc_stream=True) + send_partial( + d, + dst=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_next_group, + use_calc_stream=False) + else: + paddle.distributed.wait(tensor_send_next, use_calc_stream=True) + send_partial( + tensor_send_next, + dst=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.send_next_group, + use_calc_stream=False) + + if tensor_recv_next is not None: + if isinstance(tensor_recv_next, tuple): + for d in tensor_recv_next: + recv_partial( + d, + src=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_next_group, + use_calc_stream=True) + allgather_partial( + d, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) + + else: + recv_partial( + tensor_recv_next, + src=1, + nranks=mp_degree, + rank_id=mp_rank, + group=_hcg.recv_next_group, + use_calc_stream=True) + + allgather_partial( + tensor_recv_next, + nranks=mp_degree, + rank_id=mp_rank, + group=mp_group, + use_calc_stream=True) 
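    # Editorial note: at this point all four directional transfers requested by
    # the caller have been issued -- sends toward the previous / next stage go
    # through send_prev_group / send_next_group, receives come in through
    # recv_prev_group / recv_next_group, and every received buffer has been
    # reassembled across the model-parallel group via allgather_partial before
    # the (tensor_recv_prev, tensor_recv_next) pair is returned below.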
+ return tensor_recv_prev, tensor_recv_next + + +def recv_forward(): + if _hcg.is_first_stage: + input_tensor = None + else: + if not _send_recv_meta.has_recv_meta: + _send_recv_meta.recv_meta(_hcg.recv_prev_group) + _send_recv_meta.has_recv_meta = True + + input_tensor, _ = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=True, + recv_next=False) + return input_tensor + + +def recv_backward(): + if _hcg.is_last_stage: + output_tensor_grad = None + else: + _, output_tensor_grad = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=None, + recv_prev=False, + recv_next=True) + return output_tensor_grad + + +def send_forward(output_tensor): + if not _hcg.is_last_stage: + if not _send_recv_meta.has_send_meta: + _send_recv_meta.set_send_message(output_tensor) + _send_recv_meta.send_meta(output_tensor, _hcg.send_next_group) + _send_recv_meta.has_send_meta = True + + _p2p_helper( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=False) + + +def send_backward(input_tensor_grad): + if not _hcg.is_first_stage: + _p2p_helper( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=False, + recv_next=False) + + +def send_forward_recv_backward(output_tensor): + if _hcg.is_last_stage: + output_tensor_grad = None + else: + _, output_tensor_grad = _p2p_helper( + tensor_send_next=output_tensor, + tensor_send_prev=None, + recv_prev=False, + recv_next=True) + return output_tensor_grad + + +def send_backward_recv_forward(input_tensor_grad): + if _hcg.is_first_stage: + input_tensor = None else: - stage_id = src_stage - group_id = _hcg.get_rank_from_stage(stage_id=stage_id) - return _groups[group_id] + input_tensor, _ = _p2p_helper( + tensor_send_next=None, + tensor_send_prev=input_tensor_grad, + recv_prev=True, + recv_next=False) + return input_tensor diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py new file mode 100644 index 0000000000000..84971f2bc3557 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py @@ -0,0 +1,177 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
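# Reviewer note: this new test builds a small transformer pipeline from
# LayerDesc entries -- an embedding layer, five TransformerNetPipe layers and a
# lambda that keeps only the hidden state -- and runs it with pp_degree=2,
# mp_degree=1, dp_degree=1 and accumulate_steps = batch_size //
# micro_batch_size = 4. Driving train_batch() for five steps exercises the new
# 1F1B schedule and the shape/dtype meta exchange in p2p_communication.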
+ +from __future__ import division +from __future__ import print_function + +import unittest +import paddle +import numpy as np +import random +import paddle.distributed as dist +import paddle.distributed.fleet as fleet +from paddle.fluid import layers +import paddle.nn.functional as F +from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc +from paddle.fluid.dygraph.layers import Layer +import paddle.nn as nn + + +def set_random_seed(seed, dp_id, rank_id): + """Set random seed for reproducability.""" + random.seed(seed) + np.random.seed(seed + dp_id) + paddle.seed(seed + dp_id) + + +batch_size = 8 +length = 8 +micro_batch_size = 2 +vocab_size = 128 +hidden_size = 16 +d_model = hidden_size +dim_feedforward = 4 * d_model + + +class EmbeddingNet(Layer): + def __init__(self): + super(EmbeddingNet, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(vocab_size, hidden_size) + + def forward(self, x): + attention_mask = paddle.tensor.triu( + (paddle.ones( + (length, length), dtype="float32") * -1e9), 1) + attention_mask.stop_gradient = True + w_emb = self.word_embeddings(x) + p_emb = self.position_embeddings(x) + w_emb = w_emb + p_emb + + # need to fix bug of backward() + return w_emb, attention_mask + + +class TransformerNet(Layer): + def __init__(self): + super(TransformerNet, self).__init__() + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.q_proj = nn.Linear(d_model, d_model) + self.k_proj = nn.Linear(d_model, d_model) + self.v_proj = nn.Linear(d_model, d_model) + + self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5) + + def forward(self, x, mask): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5) + + weights = F.softmax(product + mask) + weights = F.dropout(weights, 0.2) + tgt = layers.matmul(weights, v) + residual = tgt + tgt = self.norm1(tgt) + tgt = residual + tgt + + out = self.linear2(F.gelu(self.linear1(tgt), approximate=True)) + return out + + +class EmbeddingPipe(EmbeddingNet): + def forward(self, x): + return super().forward(x) + + +class TransformerNetPipe(TransformerNet): + def forward(self, args): + x, mask = args[0], args[1] + + output = super().forward(x, mask) + output = output + mask.stop_gradient = True + return output, mask + + +class CriterionPipe(Layer): + def __init__(self): + super(CriterionPipe, self).__init__() + + def forward(self, out, label): + loss = out.mean() + return loss + + +class ModelPipe(PipelineLayer): + def __init__(self, topology): + self.descs = [] + self.descs.append(LayerDesc(EmbeddingPipe)) + + for x in range(5): + self.descs.append(LayerDesc(TransformerNetPipe)) + + self.descs.append(lambda x: x[0]) + + super().__init__( + layers=self.descs, loss_fn=CriterionPipe(), topology=topology) + + +class TestDistPPTraning(unittest.TestCase): + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 1 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 2 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + strategy.pipeline_configs = { + "accumulate_steps": batch_size // micro_batch_size, + "micro_batch_size": micro_batch_size + } + fleet.init(is_collective=True, strategy=strategy) + + def test_pp_model(self): + hcg = fleet.get_hybrid_communicate_group() + word_size = 
hcg.get_model_parallel_world_size() + dp_id = hcg.get_data_parallel_rank() + pp_id = hcg.get_stage_id() + rank_id = dist.get_rank() + topology = hcg.topology() + set_random_seed(1024, dp_id, rank_id) + + model = ModelPipe(topology) + scheduler = paddle.optimizer.lr.PiecewiseDecay( + boundaries=[2], values=[0.001, 0.002], verbose=True) + optimizer = paddle.optimizer.SGD(learning_rate=scheduler, + parameters=model.parameters()) + + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + + for step_id in range(5): + x_data = np.random.randint(0, vocab_size, size=[batch_size, length]) + x = paddle.to_tensor(x_data) + x.stop_gradient = True + loss = model.train_batch([x, x], optimizer, scheduler) + # TODO(shenliang03) add utest for loss + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py index 9f534381c98ab..62e781678c9fc 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_dygraph_pipeline_parallel.py @@ -33,6 +33,9 @@ def test_hybrid_parallel_shared_weight(self): def test_pipeline_parallel(self): self.run_mnist_2gpu('hybrid_parallel_pp_amp.py') + def test_hybrid_parallel_transformer(self): + self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py') + if __name__ == "__main__": unittest.main()
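
Reviewer note: for readers tracking the control flow, the restructured train_batch() follows the warmup / steady (1F1B) / cooldown pattern referenced from Megatron-LM's schedules.py. The sketch below reproduces only the scheduling skeleton with local no-op stubs standing in for the p2p calls and the forward/backward steps (none of these names are Paddle APIs), so the phase arithmetic can be checked in isolation:

# Minimal, single-process sketch of the 1F1B phase arithmetic used by
# train_batch(); every "communication" below is a local no-op stub.
def one_f_one_b_schedule(num_stages, stage_id, accumulate_steps):
    recv_forward = lambda: "activation"            # stub for p2p.recv_forward()
    send_forward = lambda t: None                  # stub for p2p.send_forward()
    send_forward_recv_backward = lambda t: "grad"
    send_backward = lambda g: None
    send_backward_recv_forward = lambda g: "activation"
    recv_backward = lambda: "grad"
    forward_step = lambda t: "output"
    backward_step = lambda i, o, g: "input_grad"

    startup_steps = min(num_stages - stage_id - 1, accumulate_steps)
    steady_steps = accumulate_steps - startup_steps
    input_buffers, output_buffers, trace = [], [], []

    for _ in range(startup_steps):                 # warmup: forward only
        input_tensor = recv_forward()
        output_tensor = forward_step(input_tensor)
        send_forward(output_tensor)
        input_buffers.append(input_tensor)
        output_buffers.append(output_tensor)
        trace.append("F")

    if steady_steps > 0:
        input_tensor = recv_forward()

    for i in range(steady_steps):                  # steady: one forward, one backward
        output_tensor = forward_step(input_tensor)
        output_grad = send_forward_recv_backward(output_tensor)
        input_buffers.append(input_tensor)
        output_buffers.append(output_tensor)
        input_tensor, output_tensor = input_buffers.pop(0), output_buffers.pop(0)
        input_grad = backward_step(input_tensor, output_tensor, output_grad)
        trace.append("FB")
        if i == steady_steps - 1:
            send_backward(input_grad)
        else:
            input_tensor = send_backward_recv_forward(input_grad)

    for _ in range(startup_steps):                 # cooldown: backward only
        backward_step(input_buffers.pop(0), output_buffers.pop(0), recv_backward())
        trace.append("B")
    return trace

# e.g. the first of four stages with accumulate_steps=8 runs 3 warmup
# forwards, 5 interleaved forward/backward steps and 3 trailing backwards.
print(one_f_one_b_schedule(num_stages=4, stage_id=0, accumulate_steps=8))

The last stage (startup_steps == 0) degenerates to strict alternation of forward and backward, which is what bounds the number of live activation buffers to the pipeline depth rather than to accumulate_steps.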