[HybridParallel]Support 1f1b for PipelineParallel (#34483)
* Support the 1F1B (one-forward-one-backward) schedule for pipeline parallelism (sketched below)

* Add unit tests

* Add send_partial/recv_partial

* Support AMP (automatic mixed precision) for pipeline parallelism

* Fix logger
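A schematic sketch of the 1F1B ordering referred to above, assuming hypothetical names (`one_f_one_b_order`, `num_stages`, `stage_id`, `num_micro_batches`); it illustrates the schedule only and is not the code added in this commit:

```python
# Schematic 1F1B schedule: each pipeline stage runs a few warm-up forward
# micro-batches, then alternates one forward with one backward, then drains
# the remaining backwards.
def one_f_one_b_order(num_stages, stage_id, num_micro_batches):
    order = []
    # Warm-up: earlier stages run more forwards before their first backward.
    num_warmup = min(num_stages - stage_id - 1, num_micro_batches)
    for i in range(num_warmup):
        order.append(("forward", i))
    # Steady state: one forward, then one backward.
    fwd, bwd = num_warmup, 0
    while fwd < num_micro_batches:
        order.append(("forward", fwd))
        fwd += 1
        order.append(("backward", bwd))
        bwd += 1
    # Cool-down: drain the backwards that are still outstanding.
    while bwd < num_micro_batches:
        order.append(("backward", bwd))
        bwd += 1
    return order

# Stage 0 of a 4-stage pipeline with 8 micro-batches: 3 warm-up forwards,
# then forward/backward alternation, then 3 trailing backwards.
print([p[0][0].upper() + str(p[1]) for p in one_f_one_b_order(4, 0, 8)])
# ['F0', 'F1', 'F2', 'F3', 'B0', 'F4', 'B1', 'F5', 'B2', 'F6', 'B3', 'F7', 'B4', 'B5', 'B6', 'B7']
```

Later stages get shorter warm-up phases, so each stage holds activations for at most a pipeline's worth of in-flight micro-batches, which is the memory advantage of 1F1B over running all forwards before all backwards.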
ForFishes committed Aug 2, 2021
1 parent 3b5fc2a commit 9e0bb91
Showing 5 changed files with 716 additions and 561 deletions.
python/paddle/distributed/fleet/base/topology.py (63 changes: 42 additions & 21 deletions)
@@ -156,6 +156,10 @@ def __init__(self, topology):
self.is_first_stage = (self.stage_id == 0)
self.is_last_stage = (self.stage_id == (self._pp_degree - 1))

# create p2p_groups
if self._pp_degree > 1:
self._set_p2p_group()

debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \
"sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree,
self._sharding_degree, self._pp_degree, self._dp_degree)
@@ -164,27 +168,9 @@ def __init__(self, topology):
self._dp_group, self._check_group)
logger.info(debug_str)

# create p2p_groups and no new group
self._p2p_groups = self._build_p2p_lists()

global _HYBRID_PARALLEL_GROUP
_HYBRID_PARALLEL_GROUP = self

def _build_p2p_lists(self):
comm_lists = self._topo.get_comm_list('pipe')
p2p_lists = []
for rank in range(self.nranks):
for comm_ranks in comm_lists:
assert len(comm_ranks) == self._pp_degree
if rank in comm_ranks:
idx = comm_ranks.index(rank)
next_rank = comm_ranks[(idx + 1) % self._pp_degree]
p2p_lists.append([rank, next_rank])
break
assert len(
p2p_lists) == self.nranks, "len(p2p_lists) should be equal nranks"
return p2p_lists

def get_parallel_mode(self):
# there are four modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel
# NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and
@@ -236,6 +222,41 @@ def _set_check_group(self, parallel_method="data"):

return parallel_group, parallel_comm_group

def _set_p2p_group(self):
comm_lists = self._topo.get_comm_list('pipe')

self.send_next_group = None
self.send_prev_group = None
self.recv_next_group = None
self.recv_prev_group = None

for comm_ranks in comm_lists:
assert len(comm_ranks) == self._pp_degree
for idx, rank in enumerate(comm_ranks):
curr_rank = rank
next_rank = comm_ranks[(idx + 1) % self._pp_degree]
prev_rank = comm_ranks[(idx - 1) % self._pp_degree]

next_group = paddle.distributed.new_group(
ranks=[curr_rank, next_rank])
if self.global_rank == curr_rank:
self.send_next_group = next_group
elif self.global_rank == next_rank:
self.recv_prev_group = next_group

prev_group = paddle.distributed.new_group(
ranks=[prev_rank, curr_rank])

if self.global_rank == curr_rank:
self.send_prev_group = prev_group
elif self.global_rank == prev_rank:
self.recv_next_group = prev_group

assert self.send_next_group is not None
assert self.send_prev_group is not None
assert self.recv_next_group is not None
assert self.recv_prev_group is not None

def topology(self):
return self._topo
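For reference, a standalone sketch (plain Python, no communication groups actually created) of the ring pairing that `_set_p2p_group` above performs; the two communication lists are illustrative stand-ins for what `self._topo.get_comm_list('pipe')` returns:

```python
# Each pipeline communication list is treated as a ring: every rank is paired
# with its next neighbor (backing send_next_group / recv_prev_group) and,
# symmetrically, with its previous neighbor (send_prev_group / recv_next_group).
comm_lists = [[0, 2, 4, 6], [1, 3, 5, 7]]  # e.g. two pipelines, pp_degree = 4

send_next_pairs = []
for comm_ranks in comm_lists:
    pp_degree = len(comm_ranks)
    for idx, curr_rank in enumerate(comm_ranks):
        next_rank = comm_ranks[(idx + 1) % pp_degree]
        # curr_rank sends activations forward to next_rank; for next_rank the
        # same two-rank group acts as its recv_prev_group.
        send_next_pairs.append((curr_rank, next_rank))

print(send_next_pairs)
# [(0, 2), (2, 4), (4, 6), (6, 0), (1, 3), (3, 5), (5, 7), (7, 1)]
```

Note that `_set_p2p_group` creates every two-rank group on every rank and only records the handles relevant to the local rank, presumably because group creation has to run on all processes.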

@@ -287,6 +308,9 @@ def get_pipe_parallel_world_size(self):
def get_pipe_parallel_group(self):
return self._pp_comm_group

def get_p2p_groups(self):
return self.send_next_group, self.send_prev_group, self.recv_next_group, self.recv_prev_group

# sharding parallel message:
def _get_sharding_parallel_id(self):
return self._topo.get_coord(self.global_rank).sharding
@@ -304,9 +328,6 @@ def get_sharding_parallel_group_src_rank(self):
# TODO should the src rank related to the shard rank for each parameter ?
return self._sharding_comm_group.ranks[0]

def get_p2p_groups(self):
return self._p2p_groups

# check parallel group
def get_check_parallel_group(self):
return self._check_comm_group
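A hedged usage sketch for the new getter, assuming a typical hybrid-parallel fleet setup; the strategy values and the use of fleet.get_hybrid_communicate_group() are illustrative, not taken from this commit, and the script would need to run under python -m paddle.distributed.launch with at least two processes:

```python
import paddle.distributed.fleet as fleet

# Illustrative hybrid-parallel configuration with two pipeline stages.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
# The getter added here returns the four point-to-point groups created in
# _set_p2p_group, in the order (send_next, send_prev, recv_next, recv_prev).
send_next, send_prev, recv_next, recv_prev = hcg.get_p2p_groups()
```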
