
[hybrid] out data parallel as optimizer sharding parallel #35593

Merged

Conversation


@wangxicoding wangxicoding commented Sep 8, 2021

PR types

New features

PR changes

Others

Describe

  1. Add a _dp_as_optimizer_sharding option to ShardingConfig. When enabled, the outermost data parallel group is treated as the optimizer sharding parallel group and the optimizer is sharded across it: each sharding rank stores only its own optimizer states, which reduces the persistable variables kept for the optimizer and therefore the GPU memory usage.
    For gradient communication, gradients are reduced with c_reduce_sum and the updated parameters are broadcast back with c_broadcast, so the communication volume is of the same order as c_allreduce_sum. The current _dp_as_optimizer_sharding design is a temporary scheme, not the final one. (A configuration sketch follows this list.)
  2. Sharding the optimizer can leave a rank with no parameters: for example, with a single parameter and two sharding ranks, the second rank's optimizer is assigned nothing.
    With AMP and GlobalGradientClip enabled, the inputs of the check_finite_and_unscale_op, update_loss_scaling_op and sum_op they use may then be pruned to empty. For AMP, we modify check_finite_and_unscale_op and update_loss_scaling_op so that they still run correctly with empty inputs. For GlobalGradientClip, we replace the empty sum_op with fill_constant(0.0), which does not affect the numerical correctness of the program. (A small numerical sketch follows this list.)
    A test program for this boundary condition: https://gist.github.com/wangxicoding/d3b27289a545f62bec5130fc2952a542
  3. _dp_as_optimizer_sharding already supports fuse_allreduce and fuse_grad_merge. TODO: support optimize_cast, which keeps fp16 gradients in the forward/backward passes, reducing the number of cast ops, the storage of fp32 parameters, and the communication volume of the parameter broadcast.
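
For reference, below is a minimal sketch of how the switch from item 1 might be enabled on the Python side through fleet's DistributedStrategy. The degrees, the assumption that the new proto field is reachable through sharding_configs like the existing options, and the omitted model/optimizer setup are all illustrative, not taken from this PR.

```python
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    # Illustrative hybrid degrees for one node with 8 GPUs: 2mp x 2pp x 2dp.
    "mp_degree": 2,
    "pp_degree": 2,
    "dp_degree": 2,
    "sharding_degree": 1,
    # Treat the outermost data parallel group as the optimizer sharding group,
    # so each dp rank keeps only its own slice of the optimizer states.
    "_dp_as_optimizer_sharding": True,
}

# Model definition omitted; the user optimizer is wrapped as usual, e.g.
# opt = fleet.distributed_optimizer(paddle.optimizer.Adam(), strategy=strategy)
# opt.minimize(loss)
```

Such a script would be launched with paddle.distributed.launch on the 8 GPUs; the sharding pass then splits the optimizer states across the dp group and inserts the c_reduce_sum/c_broadcast pair described above.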
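
And a tiny numpy illustration of the numerical argument in item 2: under GlobalGradientClip, each rank contributes the sum of its local squared gradient entries to the global-norm computation, and a rank that owns no parameters contributes exactly 0.0, which is precisely what replacing the pruned sum_op with fill_constant(0.0) produces. The helper below is illustrative only, not the actual pass code.

```python
import numpy as np

def local_sq_grad_sum(local_grads):
    """Per-rank sum of squared gradient entries (what sum_op feeds into
    the global-norm computation of GlobalGradientClip)."""
    if not local_grads:
        # Pruned case: the PR swaps the empty sum_op for fill_constant(0.0),
        # so a rank that owns no parameters contributes exactly zero.
        return 0.0
    return sum(float(np.sum(g ** 2)) for g in local_grads)

# Two optimizer-sharding ranks: rank 0 owns the only parameter, rank 1 owns none.
rank0_grads = [np.array([3.0, 4.0])]
rank1_grads = []

# Emulate the cross-rank allreduce(sum) followed by sqrt.
global_norm = (local_sq_grad_sum(rank0_grads) + local_sq_grad_sum(rank1_grads)) ** 0.5
print(global_norm)  # 5.0 -- identical to the unsharded result
```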

Precision test

Ernie3.0, base model, single node with 8 GPUs
baseline=2mp+2pp+2dp, optimizer_sharding=2mp+2pp+2opt_sharding
[image: precision comparison]

GPU memory test

Ernie3.0, single node with 8 GPUs

Model config:
hidden_size 3072
num_attention_heads 48
num_hidden_layers 39
num_sharding_layers 36
branch_hidden_size 256
branch_num_attention_heads 4

| GPU id | baseline (MB) | opt_sharding (MB) | memory saved (MB) |
| --- | --- | --- | --- |
| 0 | 23298 | 18720 | 4578 |
| 1 | 23274 | 18690 | 4584 |
| 2 | 27442 | 22296 | 5146 |
| 3 | 27414 | 22274 | 5140 |

@paddle-bot-old

paddle-bot-old bot commented Sep 8, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -43,6 +43,8 @@ message ShardingConfig {
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
optional int32 pp_degree = 11 [ default = 1 ];
optional bool optimize_cast = 12 [ default = false ];
// Optimizer sharding. Temporary plans and may be deprecated
optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
Contributor


why not add a new config called stage and allow two values for now: stage=1 and stage=3

Contributor Author


update later

@JZ-LIANG JZ-LIANG self-requested a review September 14, 2021 11:48
Contributor

@JZ-LIANG JZ-LIANG left a comment


LGTM

ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {


Why support an op without input/output?

Contributor Author


[image: reply posted as a screenshot]


@sandyhouse sandyhouse left a comment


LGTM

@wangxicoding wangxicoding merged commit 7846570 into PaddlePaddle:develop Sep 15, 2021
@wangxicoding wangxicoding deleted the out_dp_as_optimzier_sharding branch September 15, 2021 03:32
AnnaTrainingG pushed a commit to AnnaTrainingG/Paddle that referenced this pull request Sep 29, 2021