From 824a60225b937a662f1e9cc7954e6475b0089000 Mon Sep 17 00:00:00 2001
From: Baibaifan
Date: Tue, 8 Feb 2022 08:55:12 +0000
Subject: [PATCH] optimize sharding stage3 offload

---
 .../meta_parallel/sharding/sharding_stage3.py | 36 +++++++++++++++++--
 .../meta_parallel/sharding/sharding_utils.py  |  7 ++--
 2 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
index 8bbf42b72f2d6..00c72e28a6ffd 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py
@@ -86,7 +86,7 @@ def __init__(self,
         self._offload = offload
         self._sync_comm = sync_comm
         # segmentation size
-        self._segment_size = segment_size if not offload else 0
+        self._segment_size = segment_size
 
         global DEV
         DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device(
@@ -191,8 +191,23 @@ def _clear_gradients(self):
             param.fw_storage._gradient_set_empty(False)
             param.bw_storage._clear()
         # 2.Handle unslice param
-        for grad_storage in self._grad_storages.values():
-            grad_storage.buffer.zero_()
+        if not self._offload:
+            for grad_storage in self._grad_storages.values():
+                grad_storage.buffer.zero_()
+        else:
+            for param in list(self._unslice_params):
+                param.clear_gradient(False)
+                param._gradient_set_empty(False)
+                tmp_var = param.cuda(DEV_ID)
+                param._clear()
+                if tmp_var.dtype == Type.fp32.value and param2dtype[
+                        param.name] == Type.fp16.value:
+                    tmp_var = paddle.cast(tmp_var, Type.fp16.value)
+                tmp_var._share_buffer_to(param)
+                tmp_var._clear()
+            for grad_storage in self._grad_storages.values():
+                grad_storage.manumal_relase()
+                grad_storage.rebuild()
 
     # Update param memery slice
     def _update_params_slice(self):
@@ -455,6 +470,21 @@ def _update_params(self):
                     group=self._group,
                     use_calc_stream=True)
 
+        if self._offload:
+            for param in list(self._unslice_params):
+                tmp_var = _device2cpu(param, convert_dtype=True)
+                tmp_var._share_buffer_to(param)
+                tmp_var._clear()
+
+            for grad_storage in self._grad_storages.values():
+                for p in grad_storage._params:
+                    tmp_g = _device2cpu(p.grad, convert_dtype=True)
+                    p.clear_gradient(False)
+                    p._gradient_set_empty(False)
+                    p._copy_gradient_from(tmp_g)
+                    tmp_g._clear()
+                grad_storage.buffer._clear()
+
         return update_list
 
     def get_all_parameters(self, convert2cpu=False):
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
index ee281a0a044f4..0a42b993d5bf2 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_utils.py
@@ -131,14 +131,15 @@ def _dygraph_clip(self, params_grads):
         clip_var_fp16 = paddle.cast(clip_var, paddle.float16)
 
         for p, g in params_grads:
-            if g is None:
-                continue
-            if getattr(p, 'need_clip', True) is False:
+            if getattr(p, 'need_clip', True) is False or g is None:
                 continue
+            origin_state = g.stop_gradient
+            g.stop_gradient = True
             if p.dtype == paddle.float16:
                 g.scale_(clip_var_fp16)
             else:
                 g.scale_(clip_var)
+            g.stop_gradient = origin_state
             p._reset_grad_inplace_version(True)
 
         return params_grads
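
Notes on the changes above (illustrative sketches, not part of the applied diff):

With this change, enabling offload no longer forces the segment size to 0 in __init__, so parameter segmentation behaves the same with and without CPU offload. A hedged usage sketch of the wrapper this file implements is below; the class name and module path come from the diff, but the remaining constructor keywords (optimizer, group) are recalled from the surrounding file and may differ, and the script has to be started with paddle.distributed.launch on a GPU machine.

# Hedged usage sketch, not part of the patch. Launch with e.g.
#   python -m paddle.distributed.launch --gpus "0,1" train.py
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import ShardingStage3

fleet.init(is_collective=True)

model = paddle.nn.Linear(4096, 4096)
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters())

# After this patch, segment_size is honored even when offload=True
# (previously __init__ reset it to 0 whenever offload was enabled).
model = ShardingStage3(
    model, optimizer=opt, group=None, segment_size=2**15, offload=True)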
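
The new offload branches in _clear_gradients and _update_params keep a CPU master copy of each unslice parameter and only move it onto the GPU around the optimizer step, casting an fp32 master back to fp16 when the original parameter was fp16 (per param2dtype), then moving the updated values and gradients back to host memory. A rough sketch of that round trip follows, assuming Paddle 2.x dygraph and one visible GPU; the helper names are made up, but the Tensor methods (cuda, cpu, _share_buffer_to, _clear) are the same ones the patch uses, with _device2cpu (an internal helper of sharding_stage3.py) approximated by Tensor.cpu plus an optional cast.

# Illustrative only: mirrors the offload round trip for "unslice" parameters.
import paddle

def bring_to_gpu_for_step(param, dev_id, target_dtype=None):
    # Copy the CPU master onto the GPU, optionally casting back to fp16,
    # then make `param` point at the GPU storage.
    tmp_var = param.cuda(dev_id)
    if target_dtype == paddle.float16 and tmp_var.dtype == paddle.float32:
        tmp_var = paddle.cast(tmp_var, paddle.float16)
    param._clear()                    # drop the old host buffer
    tmp_var._share_buffer_to(param)   # param now shares the GPU allocation
    tmp_var._clear()

def send_back_to_cpu_after_step(param):
    # Rough stand-in for _device2cpu(param, convert_dtype=True): move the
    # updated value back to host memory and re-point `param` at it.
    tmp_var = param.cpu()
    tmp_var._share_buffer_to(param)
    tmp_var._clear()

if paddle.is_compiled_with_cuda():
    w = paddle.zeros([8, 8], dtype=paddle.float32).cpu()
    bring_to_gpu_for_step(w, 0, target_dtype=paddle.float16)
    send_back_to_cpu_after_step(w)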
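
In _dygraph_clip, the in-place gradient scaling is now wrapped by temporarily setting stop_gradient on the gradient tensor and restoring its original value afterwards, apparently so the scale_ op is not tracked while the gradient is modified in place (the existing p._reset_grad_inplace_version(True) call then clears the in-place version bump). A minimal sketch of that pattern, assuming Paddle 2.x dygraph; the helper name is made up for illustration.

import paddle

def scale_grad_inplace(g, clip_var):
    # Temporarily disable grad tracking on the gradient so the in-place
    # scaling is not recorded, then restore the original flag.
    origin_state = g.stop_gradient
    g.stop_gradient = True
    g.scale_(clip_var)
    g.stop_gradient = origin_state
    return g

grad = paddle.to_tensor([2.0, 4.0, 6.0])
clip_var = paddle.to_tensor(0.5)   # stands in for the computed clip ratio
scale_grad_inplace(grad, clip_var)
print(grad.numpy())                # [1. 2. 3.]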