Skip to content

Commit

Permalink
Cherry-pick for dygraph pp (#46876)
Browse files Browse the repository at this point in the history
* bug fix for virtual pipeline parallel (#45922)

* don't wait for send op under dygraph pp (#46209)

* [interleave pp] sync recv for 1f1b (#46399)

* [dygraph pp] all sync for allgather partial (#46483)
  • Loading branch information
FeixLiu committed Oct 11, 2022
1 parent 6a6c749 commit 9cc3f69
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def get_stage_from_index(self, layer_idx):
for virtual_pp_rank in range(self._num_virtual_pipeline_stages):
# Mapping the virtual pipeline stage to the real pipeline stage.
# start_idx marks the start of a new virtual pp stage.
start_idx = virtual_pp_rank * self._num_virtual_pipeline_stages
start_idx = virtual_pp_rank * self._num_stages
for stage in range(self._num_stages):
# stage mark the real pp stage
if self.segment_parts[start_idx +
Expand Down Expand Up @@ -484,7 +484,7 @@ def _segment_network_for_interleave(self, seg_method):
", ".join(str(arg) for arg in self.segment_parts))

for i in range(self._stage_id, self._total_stages_with_virtual_stages,
self._num_virtual_pipeline_stages):
self._num_stages):
# If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers.
# Layers [0, 1], [4, 5] will be assigned to the first real pp stage.
# Layers [2, 3], [6, 7] will be assigned to the second real pp stage.
Expand Down Expand Up @@ -529,7 +529,7 @@ def _print_segmentation_for_debug(self):
stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format(
stage)
for i in range(stage, self._total_stages_with_virtual_stages,
self._num_virtual_pipeline_stages):
self._num_stages):
stage_to_virtual_stage_info += " {},".format(i)
logger.info(stage_to_virtual_stage_info)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def interleave_pipeline(self,

self.set_virtual_pipeline_rank(0)
self.input_tensors[0].append(
p2p.recv_forward(self.is_pipeline_first_stage()))
p2p.recv_forward(self.is_pipeline_first_stage(), sync_recv=False))

# run startup steps
for micro_step in range(startup_steps):
Expand Down Expand Up @@ -647,7 +647,8 @@ def interleave_pipeline(self,
if not forward_only:
if all_startup_steps:
self.output_tensor_grads[self.num_model_chunks - 1].append(
p2p.recv_backward(self.is_pipeline_last_stage()))
p2p.recv_backward(self.is_pipeline_last_stage(),
sync_recv=False))

for micro_step in range(steady_steps, num_steps):
# cooldown loop
Expand Down
Loading

0 comments on commit 9cc3f69

Please sign in to comment.