optimize sharding stage3 offload (#39397)
Baibaifan committed Feb 9, 2022
1 parent c5affb7 commit b292dfb
Showing 2 changed files with 37 additions and 6 deletions.
File 1 of 2
@@ -86,7 +86,7 @@ def __init__(self,
         self._offload = offload
         self._sync_comm = sync_comm
         # segmentation size
-        self._segment_size = segment_size if not offload else 0
+        self._segment_size = segment_size

         global DEV
         DEV = "cpu" if paddle.get_device() == "cpu" else paddle.get_device(
@@ -191,8 +191,23 @@ def _clear_gradients(self):
             param.fw_storage._gradient_set_empty(False)
             param.bw_storage._clear()
         # 2.Handle unslice param
-        for grad_storage in self._grad_storages.values():
-            grad_storage.buffer.zero_()
+        if not self._offload:
+            for grad_storage in self._grad_storages.values():
+                grad_storage.buffer.zero_()
+        else:
+            for param in list(self._unslice_params):
+                param.clear_gradient(False)
+                param._gradient_set_empty(False)
+                tmp_var = param.cuda(DEV_ID)
+                param._clear()
+                if tmp_var.dtype == Type.fp32.value and param2dtype[
+                        param.name] == Type.fp16.value:
+                    tmp_var = paddle.cast(tmp_var, Type.fp16.value)
+                tmp_var._share_buffer_to(param)
+                tmp_var._clear()
+            for grad_storage in self._grad_storages.values():
+                grad_storage.manumal_relase()
+                grad_storage.rebuild()

     # Update param memery slice
     def _update_params_slice(self):
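The offload branch added above follows a pattern that recurs in this change: the CPU-resident master copy of an unsliced parameter is moved back to the GPU, cast down to fp16 when its master copy is kept in fp32, and then re-bound to the original parameter by sharing the temporary tensor's buffer. A minimal sketch of that pattern, assuming `param` is a Paddle tensor whose storage currently lives on CPU, `dev_id` is the target GPU id, and the internal helpers `_clear` / `_share_buffer_to` behave as they are used in the hunk above:

import paddle

def rebind_param_on_gpu(param, dev_id, cast_to_fp16=False):
    # Copy the CPU master copy onto the GPU, then drop the old CPU storage.
    tmp_var = param.cuda(dev_id)
    param._clear()
    # Restore the fp16 view if the master copy was kept in fp32.
    if cast_to_fp16 and tmp_var.dtype == paddle.float32:
        tmp_var = paddle.cast(tmp_var, paddle.float16)
    # Re-bind: `param` now aliases the GPU buffer; release the temporary handle.
    tmp_var._share_buffer_to(param)
    tmp_var._clear()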
@@ -455,6 +470,21 @@ def _update_params(self):
                 group=self._group,
                 use_calc_stream=True)

+        if self._offload:
+            for param in list(self._unslice_params):
+                tmp_var = _device2cpu(param, convert_dtype=True)
+                tmp_var._share_buffer_to(param)
+                tmp_var._clear()
+
+            for grad_storage in self._grad_storages.values():
+                for p in grad_storage._params:
+                    tmp_g = _device2cpu(p.grad, convert_dtype=True)
+                    p.clear_gradient(False)
+                    p._gradient_set_empty(False)
+                    p._copy_gradient_from(tmp_g)
+                    tmp_g._clear()
+                grad_storage.buffer._clear()
+
         return update_list

     def get_all_parameters(self, convert2cpu=False):
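The block added to `_update_params` moves the unsliced parameters and their reduced gradients back to CPU once communication has finished, so the optimizer step can run against host memory. The `_device2cpu` helper it relies on is not part of this diff; judging only from its call sites (`_device2cpu(x, convert_dtype=True)` returning a tensor that replaces the original), a plausible sketch is:

import paddle

def device2cpu_sketch(trans_var, convert_dtype=False):
    # Assumption: widen fp16 tensors to fp32 so the CPU master copy keeps full precision.
    if convert_dtype and trans_var.dtype == paddle.float16:
        trans_var = paddle.cast(trans_var, paddle.float32)
    tmp_var = trans_var.cpu()   # copy to host memory
    trans_var._clear()          # release the device storage
    return tmp_var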
File 2 of 2
@@ -131,14 +131,15 @@ def _dygraph_clip(self, params_grads):
         clip_var_fp16 = paddle.cast(clip_var, paddle.float16)

         for p, g in params_grads:
-            if g is None:
-                continue
-            if getattr(p, 'need_clip', True) is False:
+            if getattr(p, 'need_clip', True) is False or g is None:
                 continue
+            origin_state = g.stop_gradient
+            g.stop_gradient = True
             if p.dtype == paddle.float16:
                 g.scale_(clip_var_fp16)
             else:
                 g.scale_(clip_var)
+            g.stop_gradient = origin_state
             p._reset_grad_inplace_version(True)

         return params_grads
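In the clip-grad change, the two early-exit checks are merged into one condition, and the in-place `scale_` is now wrapped so the gradient temporarily has `stop_gradient = True` while it is modified and gets its original flag back afterwards. A small sketch of that wrapping, assuming `grad` is a Paddle gradient tensor and `clip_coef` holds the global-norm clip coefficient (as `clip_var` does above):

def scale_grad_inplace(grad, clip_coef):
    origin_state = grad.stop_gradient
    grad.stop_gradient = True          # do not track the in-place update
    grad.scale_(clip_coef)             # grad *= clip_coef, in place
    grad.stop_gradient = origin_state  # restore the original flag
    return grad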