[Trainer] Remove dp group with group_sharded_parallel checks. (#7507)
ZHUI committed Dec 8, 2023
1 parent a74916d commit 9c279f7
Showing 1 changed file with 10 additions and 34 deletions.
44 changes: 10 additions & 34 deletions paddlenlp/trainer/trainer.py
```diff
@@ -169,10 +169,6 @@
 from paddle.fluid.dataloader.dataloader_iter import _DataLoaderIterBase
 
 
-def is_dp_group_support_in_group_sharded_parallel():
-    return "dp_group" in set(inspect.signature(paddle.distributed.sharding.group_sharded_parallel).parameters.keys())
-
-
 __all__ = ["Trainer"]
 
 
```
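For context, the deleted helper is a runtime feature-detection check: it inspects the signature of `paddle.distributed.sharding.group_sharded_parallel` to see whether the installed Paddle accepts a `dp_group` argument. A minimal sketch of that pattern, with a hypothetical function name standing in for the removed helper:

```python
import inspect

from paddle.distributed.sharding import group_sharded_parallel


def supports_dp_group() -> bool:
    """Hypothetical stand-in for the removed helper: probe the sharding API's signature."""
    # Feature-detect on the public signature instead of parsing version strings,
    # so the check also works for backported or custom Paddle builds.
    return "dp_group" in inspect.signature(group_sharded_parallel).parameters
```

The commit drops this check and every path guarded by it, passing `dp_group` unconditionally further down, which implies the Paddle versions PaddleNLP now targets all accept the argument.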
```diff
@@ -897,30 +893,19 @@ def train(
                     and (step + 1) == steps_in_epoch
                 ):
                     self.timers and self.timers("forward-backward").stop()
-                    # Maunally collect gradients when group_sharded_parallel can't accept dp_group
-                    # Case 1: Use sharding stage 2/3 with dp
-                    # Case 2: Use recompute and dp
+                    # Maunally collect gradients
+                    # Case 1: Use recompute and dp
+                    # Case 2: Hack dp with master_grad
+                    # Case 3: Pipeline or sharding overlap
                     # local_rank != -1 don't means dp in networks.
                     self.timers and self.timers("all-reduce").start()
 
-                    if self.sharding and ShardingOption.SHARD_OP not in self.args.sharding:
-                        if self.args.data_parallel_degree > 1 and not is_dp_group_support_in_group_sharded_parallel():
-                            fused_allreduce_gradients(model.parameters(), fleet.get_hybrid_communicate_group())
-                            if ShardingOption.FULL_SHARD in self.args.sharding:
-                                # Why need sync on parm again ?
-                                # TODO: fix this.
-                                for p in model.parameters():
-                                    if hasattr(p, "bw_storage"):
-                                        assert p.grad is None, "This case shouldn't happen."
-                                        p.bw_storage.scale_(1.0 / self.dp_group.nranks)
-                                        paddle.distributed.all_reduce(p.bw_storage, group=self.dp_group)
-
-                    # Case 2: Use recompute and dp / sharding stage1,
+                    # Case 1: Use recompute and dp / sharding stage1,
                     # manualy collect gradient for dp.
-                    elif args.recompute and availiable_no_sync:
+                    if args.recompute and availiable_no_sync:
                         fused_allreduce_gradients(list(model.parameters()), None)
 
-                    # Case 3: hack dp with master_grad
+                    # Case 2: hack dp with master_grad
                     if dp_master_grad and not (args.recompute and availiable_no_sync):
                         fused_allreduce_gradients(list(model.parameters()), None)
 
```
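Both surviving branches rely on `fused_allreduce_gradients` to collect data-parallel gradients by hand when the framework's automatic all-reduce is skipped (recompute under `no_sync`, or the master_grad hack). A standalone sketch of that pattern, assuming a model already wrapped for data parallelism (e.g. via `paddle.DataParallel` or `fleet.distributed_model`) and the import path used by recent Paddle releases:

```python
import paddle
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients


def backward_with_manual_dp_allreduce(model, loss: paddle.Tensor) -> None:
    """Backward pass without the built-in dp all-reduce, then reduce gradients by hand."""
    # no_sync() suppresses the per-backward all-reduce the data-parallel wrapper
    # would normally insert, which is what recompute forces the Trainer to do.
    with model.no_sync():
        loss.backward()
    # One fused all-reduce over all gradients; passing None for the communicate
    # group mirrors the Trainer's own call above.
    fused_allreduce_gradients(list(model.parameters()), None)
```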
```diff
@@ -931,7 +916,7 @@
                         enable_delay_scale_loss = "enable_delay_scale_loss" in pipeline_parallel_config
                         enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
 
-                    # Pipeline parallel mode, overlap with dp
+                    # Case 3: Pipeline parallel mode, overlap with dp
                     if isinstance(self.optimizer, HybridParallelOptimizer) and not self.do_grad_scaling:
                         parameters_list = _obtain_optimizer_parameters_list(self.optimizer._inner_opt)
 
```
```diff
@@ -1696,14 +1681,6 @@ def get_expected_keys(inputs, keys):
                    self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
                self.optimizer = fleet.distributed_optimizer(self.optimizer)
            else:
-                # sync params (broadcast) buffers in dp group
-                if not is_dp_group_support_in_group_sharded_parallel() and self.args.data_parallel_degree > 1:
-                    from paddle.distributed.parallel import sync_params_buffers
-
-                    hcg = fleet.get_hybrid_communicate_group()
-                    dp_group = hcg.get_data_parallel_group()
-                    sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])
-
                cpu_offload = ShardingOption.OFFLOAD in self.args.sharding
                assert self.optimizer is not None, "optimizer is empty!"
                level = None
```
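The deleted fallback broadcast parameters and buffers across the data-parallel group before the model was sharded, for Paddle builds whose `group_sharded_parallel` could not accept `dp_group`. Pulled out as a standalone sketch (assuming fleet has already been initialized in hybrid-parallel mode; `model` is a placeholder argument):

```python
from paddle.distributed import fleet
from paddle.distributed.parallel import sync_params_buffers


def broadcast_from_dp_rank0(model) -> None:
    """Broadcast parameters and buffers from the first rank of the data-parallel group."""
    hcg = fleet.get_hybrid_communicate_group()
    dp_group = hcg.get_data_parallel_group()
    # src_rank is the first global rank in the dp group, exactly as the removed code did.
    sync_params_buffers(model, comm_group=dp_group, src_rank=dp_group.ranks[0])
```

With the version check gone, the Trainer instead hands `dp_group` straight to `group_sharded_parallel` (next hunk) and relies on it for data-parallel synchronization.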
```diff
@@ -1717,9 +1694,8 @@
                # add dp_group and exclude_layer params
                # https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/distributed/sharding/group_sharded_parallel_cn.html#group-sharded-parallel
                extra_kwargs = {}
-                if is_dp_group_support_in_group_sharded_parallel():
-                    extra_kwargs["dp_group"] = self.dp_group
-                    extra_kwargs["exclude_layer"] = ["GroupNorm"]
+                extra_kwargs["dp_group"] = self.dp_group
+                extra_kwargs["exclude_layer"] = ["GroupNorm"]
 
                if self.args.amp_master_grad:
                    assert (
```
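For reference, `extra_kwargs` is presumably spread into the `group_sharded_parallel` call that follows this hunk in trainer.py. A hedged sketch of what that call shape looks like after the change, with placeholder model/optimizer/scaler objects; keyword names follow the API document linked in the comment above:

```python
from paddle.distributed.sharding import group_sharded_parallel


def shard_model(model, optimizer, scaler, level, dp_group, cpu_offload=False):
    """Wrap model and optimizer with group sharding, forwarding the dp group explicitly."""
    # dp_group and exclude_layer are now passed unconditionally; the values mirror
    # what trainer.py builds in extra_kwargs (see the linked doc for exclude_layer's
    # exact semantics).
    return group_sharded_parallel(
        model,
        optimizer,
        level,  # "os", "os_g", or "p_g_os" per the Paddle docs
        scaler=scaler,
        offload=cpu_offload,
        dp_group=dp_group,
        exclude_layer=["GroupNorm"],
    )
```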