From fad4b3b4964c16ec4519fb5d62218204ab1eaced Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Mon, 23 Aug 2021 15:46:17 +0800
Subject: [PATCH] [hybrid performance] optim the grad fuse for pipeline mode
 by sorting the grad by dtype (#35070)

---
 python/paddle/fluid/optimizer.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 58f61b77fd1fe..478ea75472717 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -5216,6 +5216,9 @@ def _accumulate_gradients_with_fuse(self, main_block, fp16, fused_size):
         if len(grad_param_pairs) == 0:
             return
 
+        grad_param_pairs = self._sort_grad_param_by_dtype(main_block,
+                                                          grad_param_pairs)
+
         grad_param_segments = []
         merged_suffix = '@MERGED@FP16' if fp16 else '@MERGED'
         dtype = paddle.float16 if fp16 else paddle.float32
@@ -5409,6 +5412,24 @@ def _accumulate_gradients_with_fuse(self, main_block, fp16, fused_size):
 
         return fused_merged_gradients
 
+    def _sort_grad_param_by_dtype(self, main_block, grad_param_pairs):
+        # sort the grad param pairs by dtype
+        fp16_pairs = []
+        fp32_pairs = []
+        other_pairs = []
+        for pairs in grad_param_pairs:
+            dtype = main_block.var(pairs[0]).dtype
+            if dtype == paddle.float32:
+                fp32_pairs.append(pairs)
+            elif dtype == paddle.float16:
+                fp16_pairs.append(pairs)
+            else:
+                other_pairs.append(pairs)
+        sorted_pairs = fp16_pairs
+        sorted_pairs.extend(fp32_pairs)
+        sorted_pairs.extend(other_pairs)
+        return sorted_pairs
+
     def _get_var_size(self, var):
         dtype_to_size = {
             core.VarDesc.VarType.FP16: 2,
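
Reviewer note (not part of the patch): the new helper reorders the (grad, param) pairs so that all fp16 gradients come first, then fp32, then any remaining dtypes. Presumably this keeps same-dtype gradients contiguous, so the downstream fusion in _accumulate_gradients_with_fuse can pack them into fewer dtype-homogeneous fused buffers instead of breaking a buffer every time the dtype changes. The sketch below reproduces only that ordering in isolation, without Paddle; MockVar and MockBlock are hypothetical stand-ins for main_block.var() introduced purely for this demonstration.

# Standalone sketch of the ordering produced by _sort_grad_param_by_dtype.
# MockVar / MockBlock are hypothetical stand-ins for Paddle's block/var API.

class MockVar:
    def __init__(self, dtype):
        self.dtype = dtype

class MockBlock:
    def __init__(self, vars_by_name):
        self._vars = vars_by_name

    def var(self, name):
        return self._vars[name]

def sort_grad_param_by_dtype(main_block, grad_param_pairs):
    # Bucket the pairs by the dtype of the gradient var, then concatenate
    # the buckets: fp16 first, fp32 second, anything else last.
    fp16_pairs, fp32_pairs, other_pairs = [], [], []
    for pair in grad_param_pairs:
        dtype = main_block.var(pair[0]).dtype
        if dtype == "float32":
            fp32_pairs.append(pair)
        elif dtype == "float16":
            fp16_pairs.append(pair)
        else:
            other_pairs.append(pair)
    return fp16_pairs + fp32_pairs + other_pairs

block = MockBlock({
    "w1@GRAD": MockVar("float32"),
    "w2@GRAD": MockVar("float16"),
    "w3@GRAD": MockVar("float32"),
    "w4@GRAD": MockVar("float16"),
})
pairs = [("w1@GRAD", "w1"), ("w2@GRAD", "w2"),
         ("w3@GRAD", "w3"), ("w4@GRAD", "w4")]
print(sort_grad_param_by_dtype(block, pairs))
# [('w2@GRAD', 'w2'), ('w4@GRAD', 'w4'), ('w1@GRAD', 'w1'), ('w3@GRAD', 'w3')]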