From 616c62f3ad5041a710ba9f622bcd2eb6028b2130 Mon Sep 17 00:00:00 2001 From: Baibaifan Date: Tue, 15 Feb 2022 11:06:22 +0000 Subject: [PATCH] optimizer sharding paramters --- .../sharding_optimizer_stage2.py | 24 ++++++++++- .../meta_parallel/sharding/sharding_stage2.py | 40 +------------------ .../meta_parallel/sharding/sharding_stage3.py | 19 ++++----- .../unittests/dygraph_sharding_stage2.py | 15 +++---- .../dygraph_sharding_stage2_offload.py | 5 +-- .../unittests/dygraph_sharding_stage3.py | 14 ++----- .../dygraph_sharding_stage3_offload.py | 6 +-- 7 files changed, 45 insertions(+), 78 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py index ea17f96f7a1ca..08baeae89ad4a 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/sharding_optimizer_stage2.py @@ -65,9 +65,9 @@ def __init__(self, params, optim, group=None, - broadcast_fp16=False, offload=False, device="gpu", + pertrain_sync_models=True, **kw): super().__init__(optim._learning_rate, params, kw) @@ -98,8 +98,12 @@ def __init__(self, self.world_size = self.group.nranks self.rank = self.group.rank + self._global_root_rank = 0 + + # Synchronous all ranks models + if pertrain_sync_models: + self._sync_params_and_buffers() - self.broadcast_fp16 = broadcast_fp16 self.param_storages = {} # {dtype: {rank: InternalStorage}} if isinstance(self._optim._grad_clip, ClipGradByGlobalNorm): @@ -132,6 +136,22 @@ def __init__(self, # Update optimizer parameters and adjust parameter storage and use according to rank. self._update_opt_status() + @paddle.no_grad() + def _sync_params_and_buffers(self): + """ + Sync all model states for all ranks + """ + + for p in self._local_params: + dist.broadcast( + p, + src=self._global_root_rank, + group=self.group, + use_calc_stream=True) + + # Multi stream operation will be supported later + dist.wait(tensor=p, group=self.group, use_calc_stream=True) + def _generate_master_params(self, trainable_params): if self.offload: for param in trainable_params: diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py index d884c416fa92c..e654f88f0b7b8 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py @@ -61,12 +61,10 @@ def __init__( sharding_optimizer, group=None, sync_buffers=False, - pertrain_sync_models=True, buffer_max_size=2**23, #8MB auto_refresh_trainable=True, device="gpu", - use_grad_storage=True, - accumulate_grads=False): + use_grad_storage=True): super().__init__() # training options @@ -81,9 +79,6 @@ def __init__( self._sync_buffers = sync_buffers self._auto_refresh_trainable = auto_refresh_trainable - # Gradient accumulation, Gradient flip - self._accumulate_grads = accumulate_grads - # Communication related attributes self._group = dist.new_group(_get_global_group() .ranks) if group is None else group @@ -128,16 +123,11 @@ def __init__( # Set backward pass hooks self._bw_hooks = [] - # Synchronous all ranks models - if pertrain_sync_models: - self._sync_params_and_buffers() - # Set tasks flow self._tasks_flow = deque() # Define optimizer step and clear_grad - if self._accumulate_grads: - self._redefine_opt_step() + self._redefine_opt_step() self._redefine_opt_clear() def forward(self, *inputs, **kwargs): @@ -313,9 +303,6 @@ def reduce(*_): # Change reduce information self._grad_reduced[index] = False - if not self._accumulate_grads: - param.grad.scale_(scale=self._world_size_scaling) - param._reset_grad_inplace_version(True) # Clear the gradient that does not belong to the current rank through the callback function def cleanup(): @@ -362,11 +349,6 @@ def reduce(*_): if grad_storage.all_checked_in: assert grad_storage.buffer is not None - # Normalize all ranks grad_storage - if not self._accumulate_grads: - grad_storage.buffer.scale_( - scale=self._world_size_scaling) - # Clearing up the grad_storage buffer def cleanup(): if dst_rank != self._rank: @@ -432,22 +414,6 @@ def _setup_backward_hooks(self): self._bw_hooks.append( param._register_backward_hook(reduce_function)) - @paddle.no_grad() - def _sync_params_and_buffers(self): - """ - Sync all model states for all ranks - """ - - for t in self._layer.parameters(): - dist.broadcast( - t, - src=self._global_root_rank, - group=self._group, - use_calc_stream=True) - - # Multi stream operation will be supported later - dist.wait(tensor=t, group=self._group, use_calc_stream=True) - def _setup_use_grad_storage(self): """ Integrate the parameters gradient into a continuous memory according to rank, and support the update of training parameters. @@ -555,8 +521,6 @@ def _rank_buffer_size(self, buffer_max_size, model_size): return rank_buffer_size def _redefine_opt_step(self): - if not self._accumulate_grads: - return grad_func = self._grad_scale for opt in self._sharding_optimizers: opt_step = opt.step diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py index 00c72e28a6ffd..9f9811b9eb0fc 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py @@ -72,7 +72,6 @@ def __init__(self, device="gpu", segment_size=2**15, pertrain_sync_models=True, - accumulate_grads=False, offload=False, sync_comm=False): super().__init__() @@ -82,7 +81,6 @@ def __init__(self, self._layer = layer self._default_device = device self.__sync_buffers = sync_buffers - self._accumulate_grads = accumulate_grads self._offload = offload self._sync_comm = sync_comm # segmentation size @@ -190,6 +188,7 @@ def _clear_gradients(self): param.fw_storage.clear_gradient(False) param.fw_storage._gradient_set_empty(False) param.bw_storage._clear() + param.bw_storage = None # 2.Handle unslice param if not self._offload: for grad_storage in self._grad_storages.values(): @@ -446,13 +445,12 @@ def _update_params(self): param, "fw_storage"), "Find {} don't have fw_storage attribute".format( param.name) - - if self._accumulate_grads: - if self._offload: - with device_guard(device="cpu"): - param.bw_storage.scale_(scale=self._world_size_scaling) - else: + # Gradient average + if self._offload: + with device_guard(device="cpu"): param.bw_storage.scale_(scale=self._world_size_scaling) + else: + param.bw_storage.scale_(scale=self._world_size_scaling) param.fw_storage = _VarBaseWrapper(param) assert param.fw_storage.grad is None param.fw_storage._copy_gradient_from(param.bw_storage) @@ -526,8 +524,6 @@ def _get_allreduce_fn(self, param): def reduce(*_): if param.name in self._task_flow.full_grad.keys(): full_grad = self._task_flow.full_grad[param.name] - if not self._accumulate_grads: - full_grad.scale_(scale=self._world_size_scaling) # Only support sync allreduce current rank's layer now dist.all_reduce( tensor=full_grad, group=self._group, use_calc_stream=True) @@ -535,8 +531,7 @@ def reduce(*_): tensor=full_grad, group=self._group, use_calc_stream=True) start, end = self._param2buffer[param.name][self._rank] - if not self._accumulate_grads or param.bw_storage is None or not param.bw_storage.value( - ).get_tensor()._is_initialized(): + if param.bw_storage is None: param.bw_storage = core.VarBase( full_grad._slice(start, end)).detach().clone() if self._offload: diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py index 80acf7217e76f..06935e212c3cb 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py @@ -27,7 +27,7 @@ from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2 from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2 -seed = 2021 +seed = 2022 epoch = 2 linear_size = 1000 @@ -105,11 +105,7 @@ def train_mlp(model, params=model.parameters(), optim=optimizer, group=group) model = ShardingStage2( - model, - optimizer, - group=group, - buffer_max_size=2**21, - accumulate_grads=batch_size == 20) + model, optimizer, group=group, buffer_max_size=2**21) else: optimizer = fleet.distributed_optimizer(optimizer) model = fleet.distributed_model(model) @@ -140,6 +136,8 @@ def train_mlp(model, loss = paddle.nn.functional.cross_entropy(input=out, label=label) avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32)) + if batch_size == 20: + avg_loss = avg_loss / 5 avg_loss.backward() if not accumulate_grad: @@ -166,6 +164,7 @@ def test_dp_stage2(): mlp4.set_state_dict(state_dict) mlp5.set_state_dict(state_dict) + # DP VS stage2 dp_params = train_mlp( mlp1, sharding_stage="dp", use_pure_fp16=False, opt_group=False) stage2_params = train_mlp( @@ -174,7 +173,8 @@ def test_dp_stage2(): np.testing.assert_allclose( dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6) - stage2_params = train_mlp(mlp3, sharding_stage=2) + # stage2 accumulate grad + stage2_params = train_mlp(mlp3, sharding_stage=2, accumulate_grad=True) stage2_accumulate_grad = train_mlp( mlp4, sharding_stage=2, batch_size=20, accumulate_grad=True) for i in range(len(stage2_params)): @@ -184,6 +184,7 @@ def test_dp_stage2(): rtol=1e-5, atol=1e-5) + # stage2 param list VS param group stage2_params = train_mlp( mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True) for i in range(len(dp_params)): diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py index 84ffe9094d812..39ba44815d940 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage2_offload.py @@ -43,13 +43,12 @@ def train_mlp(model, offload=False): optimizer = optimizer_setting(model=model, use_pure_fp16=True) model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32') - scaler = paddle.amp.GradScaler(init_loss_scaling=32768) + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) scaler = ShardingScaler(scaler) optimizer = ShardingOptimizerStage2( params=model.parameters(), optim=optimizer, offload=offload) - model = ShardingStage2( - model, optimizer, buffer_max_size=2**21, accumulate_grads=False) + model = ShardingStage2(model, optimizer, buffer_max_size=2**21) train_reader = paddle.batch( reader_decorator(linear_size), batch_size=batch_size, drop_last=True) diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py index 9bb1f85f327c3..6b755cf4c2b59 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py @@ -101,18 +101,10 @@ def train_mlp(model, optimizer = ShardingOptimizerStage2( params=model.parameters(), optim=optimizer, group=group) model = ShardingStage2( - model, - optimizer, - group=group, - buffer_max_size=2**21, - accumulate_grads=batch_size == 20) + model, optimizer, group=group, buffer_max_size=2**21) elif sharding_stage == 3: model = ShardingStage3( - model, - optimizer=optimizer, - group=group, - accumulate_grads=batch_size == 20, - sync_comm=recompute) + model, optimizer=optimizer, group=group, sync_comm=recompute) # check optimizer.minimize() error if test_minimize: @@ -231,7 +223,7 @@ def test_stage2_stage3(): stage2_params[i].numpy(), stage3_params[i].numpy(), rtol=1e-4, - atol=1e-4) + atol=1e-3) # fp16 recompute stage3_params = train_mlp( diff --git a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py index aa440549cf147..df7ba78d345a3 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py +++ b/python/paddle/fluid/tests/unittests/dygraph_sharding_stage3_offload.py @@ -91,11 +91,7 @@ def train_mlp(model, scaler = ShardingScaler(scaler) model = ShardingStage3( - model, - optimizer=optimizer, - group=group, - offload=offload, - accumulate_grads=accumulate_grad) + model, optimizer=optimizer, group=group, offload=offload) train_reader = paddle.batch( reader_decorator(), batch_size=batch_size, drop_last=True)