Commit

[cherry-pick][hybrid performance] optimize the grad fuse for pipeline mode by sorting the grads by dtype (#35070) (#35300)
FeixLiu committed Aug 31, 2021
1 parent e931cd1 commit e69cc21
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions python/paddle/fluid/optimizer.py
@@ -5223,6 +5223,9 @@ def _accumulate_gradients_with_fuse(self, main_block, fp16, fused_size):
if len(grad_param_pairs) == 0:
return

grad_param_pairs = self._sort_grad_param_by_dtype(main_block,
grad_param_pairs)

grad_param_segments = []
merged_suffix = '@MERGED@FP16' if fp16 else '@MERGED'
dtype = paddle.float16 if fp16 else paddle.float32
@@ -5416,6 +5419,24 @@ def _accumulate_gradients_with_fuse(self, main_block, fp16, fused_size):

return fused_merged_gradients

def _sort_grad_param_by_dtype(self, main_block, grad_param_pairs):
# sort the grad-param pairs by dtype: fp16 first, then fp32, then the rest
fp16_pairs = []
fp32_pairs = []
other_pairs = []
for pairs in grad_param_pairs:
dtype = main_block.var(pairs[0]).dtype
if dtype == paddle.float32:
fp32_pairs.append(pairs)
elif dtype == paddle.float16:
fp16_pairs.append(pairs)
else:
other_pairs.append(pairs)
sorted_pairs = fp16_pairs
sorted_pairs.extend(fp32_pairs)
sorted_pairs.extend(other_pairs)
return sorted_pairs

def _get_var_size(self, var):
dtype_to_size = {
core.VarDesc.VarType.FP16: 2,
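Why the sort helps: the fuse pass above packs runs of consecutive same-dtype gradients into shared buffers, so interleaved fp16/fp32 grads fragment the fusion into more buffers. Below is a minimal standalone sketch of the grouping idea (plain Python; the grad names, dtype strings, and the buffer-counting model are illustrative assumptions, and the real pass additionally splits buffers by fused_size):

# Sketch: dtype-first sorting before gradient fusion.
# Names and dtypes are made up; in Paddle the dtype would come from
# main_block.var(grad_name).dtype.
def sort_grad_param_by_dtype(grad_param_pairs, dtype_of):
    # stable partition: fp16 first, then fp32, then everything else
    fp16, fp32, other = [], [], []
    for pair in grad_param_pairs:
        bucket = {"float16": fp16, "float32": fp32}.get(dtype_of[pair[0]], other)
        bucket.append(pair)
    return fp16 + fp32 + other

def count_fused_buffers(pairs, dtype_of):
    # simplified model: each run of consecutive same-dtype grads fuses
    # into one buffer (ignores the fused_size limit of the real pass)
    buffers, prev = 0, None
    for grad_name, _ in pairs:
        if dtype_of[grad_name] != prev:
            buffers += 1
            prev = dtype_of[grad_name]
    return buffers

pairs = [("w0@GRAD", "w0"), ("w1@GRAD", "w1"),
         ("w2@GRAD", "w2"), ("w3@GRAD", "w3")]
dtype_of = {"w0@GRAD": "float16", "w1@GRAD": "float32",
            "w2@GRAD": "float16", "w3@GRAD": "float32"}

print(count_fused_buffers(pairs, dtype_of))  # 4 buffers when interleaved
sorted_pairs = sort_grad_param_by_dtype(pairs, dtype_of)
print(count_fused_buffers(sorted_pairs, dtype_of))  # 2 buffers after sorting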
