Skip to content

Commit

Permalink
[HybridParallel]Add SharedLayerDesc for PipelineParallel (#33578)
Browse files Browse the repository at this point in the history
* add pplayer

* add sharedlayerdesc
  • Loading branch information
ForFishes committed Jun 16, 2021
1 parent 07197fb commit 294dfd2
Show file tree
Hide file tree
Showing 9 changed files with 358 additions and 7 deletions.
4 changes: 3 additions & 1 deletion python/paddle/distributed/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,9 @@ def new_group(ranks=None, backend=None):

# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by cross-creation of new_group
tmp = fill_constant([0], dtype="int32", value="1")
tmp = paddle.to_tensor(
[1], dtype="int32") if in_dygraph_mode() else fill_constant(
[0], dtype="int32", value="1")
paddle.distributed.all_reduce(tmp, use_calc_stream=True)
paddle.distributed.wait(tmp)
return gp
Expand Down
12 changes: 8 additions & 4 deletions python/paddle/distributed/fleet/base/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ def get_comm_list(self, axis_name):

return all_result

def get_rank_from_stage(self, global_rank, **kwargs):
    """Return the global rank whose topology coordinate equals that of
    *global_rank* with the axes given as keyword arguments overridden
    (e.g. ``pipe=2`` to target the same (data, model) slot at stage 2).
    """
    coord = self.get_coord(global_rank)
    target = coord._replace(**kwargs)
    return self.get_rank(**target._asdict())


class HybridCommunicateGroup(object):
def __init__(self, topology):
Expand Down Expand Up @@ -254,7 +259,6 @@ def get_pipe_parallel_group(self):
def get_check_parallel_group(self):
    """Return the stored ``_check_comm_group`` process group.
    NOTE(review): presumably used for cross-parallel consistency checks
    (cf. the temporary all_reduce in ``new_group``) — confirm against
    where ``_check_comm_group`` is constructed.
    """
    return self._check_comm_group

def get_rank_from_stage(self, stage_id):
coord = self._topo.get_coord(self.global_rank)
tf = coord._replace(pipe=stage_id)._asdict()
return self._topo.get_rank(**tf)
def get_rank_from_stage(self, stage_id, **kwargs):
    """Return the global rank located at pipeline stage *stage_id* whose
    remaining topology coordinates match this rank's; extra axes may be
    overridden via keyword arguments (forwarded to the topology helper).
    """
    return self._topo.get_rank_from_stage(
        self.global_rank, pipe=stage_id, **kwargs)
1 change: 1 addition & 0 deletions python/paddle/distributed/fleet/meta_parallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .parallel_layers import RowParallelLinear # noqa: F401
from .parallel_layers import ParallelCrossEntropy # noqa: F401
from .parallel_layers import LayerDesc # noqa: F401
from .parallel_layers import SharedLayerDesc # noqa: F401
from .parallel_layers import PipelineLayer # noqa: F401
from .parallel_layers import RNGStatesTracker # noqa: F401
from .parallel_layers import model_parallel_random_seed # noqa: F401
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .mp_layers import RowParallelLinear # noqa: F401
from .mp_layers import ParallelCrossEntropy # noqa: F401
from .pp_layers import LayerDesc # noqa: F401
from .pp_layers import SharedLayerDesc # noqa: F401
from .pp_layers import PipelineLayer # noqa: F401
from .random import RNGStatesTracker # noqa: F401
from .random import model_parallel_random_seed # noqa: F401
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import paddle
from paddle.fluid.dygraph.layers import Layer
from ...utils.log_util import logger, layer_to_str
from functools import partial

__all__ = []

Expand Down Expand Up @@ -58,6 +59,20 @@ def __repr__(self):
**self.kwargs)


class SharedLayerDesc(LayerDesc):
    """Descriptor for a layer whose single built instance is shared by
    several positions of a pipeline-parallel model (e.g. tied embedding
    and output projection).

    key: name under which the built layer is stored/shared.
    layer_func/inputs/kwargs: forwarded unchanged to ``LayerDesc``.
    forward_func: optional callable invoked as
        ``forward_func(shared_layer, ...)`` instead of the layer itself.
    shared_weight_attr: attribute name of the tied weight on the layer.
    """

    def __init__(self,
                 key,
                 layer_func,
                 forward_func=None,
                 shared_weight_attr='weight',
                 *inputs,
                 **kwargs):
        super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs)
        self.shared_weight_attr = shared_weight_attr
        self.forward_func = forward_func
        self.layer_name = key


class PipelineLayer(Layer):
def __init__(self,
layers,
Expand Down Expand Up @@ -104,11 +119,86 @@ def __init__(self,
self._start_pos = 0
self._end_pos = self._num_layers - 1
self._segment_network(seg_method)
self.shared_layers = paddle.nn.LayerDict()
self.shared_weight_attrs = {}

# construct layer
self.run_function = []
self._build_layer()

self.shared_comm = self._construct_shared_comm()
self._synchronize_shared_weights()

def get_stage_from_index(self, layer_idx):
    """Return the pipeline stage whose segment contains *layer_idx*.

    ``segment_parts`` holds the half-open stage boundaries, so stage ``s``
    owns indices in ``[segment_parts[s], segment_parts[s + 1])``.
    """
    assert 0 <= layer_idx < self._num_layers, "layer_idx is out of bound"
    num_stages = self._topo.get_dim('pipe')
    for stage in range(num_stages):
        lo = self.segment_parts[stage]
        hi = self.segment_parts[stage + 1]
        if lo <= layer_idx < hi:
            return stage

def _construct_shared_comm(self):
    """Create, for every shared-layer key, the communication group that
    links the pipeline stages holding a copy of that layer.

    Returns a dict mapping key -> {'ranks', 'group', 'weight_attr',
    'layer'} containing only the groups this rank belongs to. Always
    returns a dict (possibly empty) so callers can iterate it.
    """
    shared_comm = {}
    if self._topo.get_dim("pipe") == 1:
        # No pipeline parallelism: nothing is shared across stages.
        # Return an empty dict, not None — _synchronize_shared_weights
        # and allreduce_shared_weight_gradients iterate .items() on it.
        return shared_comm

    layers_desc = self._layers_desc
    shared_layer_names = set(
        s.layer_name for s in layers_desc if isinstance(s, SharedLayerDesc))

    # Degrees are properties of the topology, not of a particular key:
    # read them once, outside the per-key loop.
    self._dp_degree = self._topo.get_dim('data')
    self._mp_degree = self._topo.get_dim('model')

    for key in shared_layer_names:
        # All layer indices that reuse this shared layer.
        shared_layers = [
            idx for idx, layer in enumerate(layers_desc)
            if isinstance(layer, SharedLayerDesc) and layer.layer_name == key
        ]
        # Pipeline stages that hold a copy of the layer.
        shared_stages = set(
            self.get_stage_from_index(idx) for idx in shared_layers)

        for dp in range(self._dp_degree):
            for mp in range(self._mp_degree):
                shared_ranks = [
                    self._topo.get_rank_from_stage(
                        self.global_rank, pipe=s, data=dp, model=mp)
                    for s in sorted(shared_stages)
                ]
                # new_group is collective: every rank must call it for
                # every group, including groups it does not belong to.
                group = paddle.distributed.new_group(ranks=shared_ranks)
                if self.global_rank in shared_ranks:
                    assert key in self.shared_layers
                    # Store only the group this rank participates in.
                    # (The original stored an entry on every (dp, mp)
                    # iteration whenever the layer existed locally,
                    # overwriting this rank's group with the last one.)
                    shared_comm[key] = {
                        'ranks': shared_ranks,
                        'group': group,
                        'weight_attr': self.shared_weight_attrs[key],
                        'layer': self.shared_layers[key],
                    }
    return shared_comm

def _synchronize_shared_weights(self):
    """Broadcast every shared weight from the lowest rank of its comm
    group so all stages holding a copy start from identical values."""
    for comm in self.shared_comm.values():
        src_rank = min(comm['ranks'])
        weight = getattr(comm['layer'], comm['weight_attr'])
        with paddle.framework.no_grad():
            paddle.distributed.broadcast(
                weight, src=src_rank, group=comm['group'])

def allreduce_shared_weight_gradients(self):
    """Sum the gradients of every shared weight across the pipeline
    stages holding a copy (called after the backward passes of a batch,
    before the optimizer step), so each replica applies the same update.
    """
    for key, comm in self.shared_comm.items():
        param = getattr(self.shared_layers[key], comm['weight_attr'])
        # need use trace_op to allreduce weight
        # NOTE(review): the raw c_allreduce_sum op is traced directly so
        # the grad ivar is summed in place on the group's ring; presumably
        # paddle.distributed.all_reduce cannot target grad vars here —
        # confirm against the dygraph tracer API.
        with paddle.framework.no_grad():
            paddle.fluid.framework._dygraph_tracer().trace_op(
                type="c_allreduce_sum",
                inputs={'X': param._grad_ivar()},
                outputs={'Out': param._grad_ivar()},
                attrs={
                    'ring_id': comm['group'].id,
                    'use_calc_stream': True
                })

def _segment_network(self, seg_method):
logger.info("start segment network..")
seg = SegmentLayers(
Expand Down Expand Up @@ -142,6 +232,21 @@ def _build_layer(self):
if isinstance(layer, Layer):
self.run_function.append(layer)
self.add_sublayer(str(layer_index), layer)
elif isinstance(layer, SharedLayerDesc):
if layer.layer_name not in self.shared_layers:
self.shared_layers[layer.layer_name] = layer.build_layer()
self.shared_weight_attrs[
layer.layer_name] = layer.shared_weight_attr

if layer.forward_func is None:
self.run_function.append(self.shared_layers[
layer.layer_name])

else:
self.run_function.append(
partial(layer.forward_func, self.shared_layers[
layer.layer_name]))

elif isinstance(layer, LayerDesc):
model = layer.build_layer()
self.run_function.append(model)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def train_batch(self, data, optimizer, lr_scheduler=None):
self._backward(cache_id=backward_steps)
backward_steps += 1

self._layers.allreduce_shared_weight_gradients()

# optimizer
self._step()
self.train_loss = self._reduce_final_loss()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,8 @@ def test_parallel_embedding(self):
np.testing.assert_allclose(loss_a.numpy(), loss_b.numpy())

def test_parallel_cross_entropy(self):
batch_size = 2
seq_length = 1
batch_size = 8
seq_length = 16
class_size_per_card = 2
vocab_size = class_size_per_card * self.model_parallel_size
seed = 1025
Expand Down
Loading

0 comments on commit 294dfd2

Please sign in to comment.