[hybrid] out data parallel as optimizer sharding parallel #35593
Conversation
Thanks for your contribution!
force-pushed from 7c0c5d8 to 35f0a4b
force-pushed from bbac81d to f5a5597
@@ -43,6 +43,8 @@ message ShardingConfig {
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
optional int32 pp_degree = 11 [ default = 1 ];
optional bool optimize_cast = 12 [ default = false ];
// Optimizer sharding. Temporary plans and may be deprecated
optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
Why not add a new config called stage and allow two values for now: stage=1 and stage=3?
update later
LGTM
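For context, a minimal usage sketch of how this flag could be enabled from the Python side, assuming it is exposed through fleet's sharding_configs dict under the same key as the proto field above (the degree values are illustrative, not part of this diff):

import paddle.distributed.fleet as fleet

# Sketch only: the "_dp_as_optimizer_sharding" key is assumed to mirror the
# ShardingConfig proto field added above; the parallel degrees are examples.
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "mp_degree": 2,                      # tensor model parallel
    "pp_degree": 2,                      # pipeline parallel
    "dp_degree": 2,                      # outer data parallel group to reuse
    "_dp_as_optimizer_sharding": True,   # shard optimizer states over the dp group
}
fleet.init(is_collective=True, strategy=strategy)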
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
Why does this support an op without inputs/outputs?
LGTM
PR types
New features
PR changes
Others
Describe
Add a _dp_as_optimizer_sharding setting to ShardingConfig, which treats the outermost data parallel group as an optimizer sharding parallel group. The optimizer is sharded across that group, so each sharding rank stores only its own optimizer states, which reduces the optimizer-related persistable vars that must be kept and lowers GPU memory usage. For gradient communication, c_reduce_sum is used to reduce each gradient to its owning rank, and after the parameters are updated they are broadcast back with c_broadcast; the communication volume is of the same order as c_allreduce_sum.
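Purely as an illustration of this reduce-then-broadcast pattern (not the actual pass, which rewrites the static program with c_reduce_sum / c_broadcast ops), a dynamic-graph sketch using paddle.distributed collectives, with a hypothetical round-robin placement of optimizer states:

import paddle
import paddle.distributed as dist

def optimizer_sharding_step(params_and_grads, lr, cur_rank, num_ranks):
    for idx, (param, grad) in enumerate(params_and_grads):
        owner = idx % num_ranks  # hypothetical placement of this param's optimizer state
        # c_reduce_sum: gradient is summed onto the owning rank only.
        dist.reduce(grad, dst=owner)
        # Only the owner applies the update (plain SGD here, just for illustration).
        if cur_rank == owner:
            paddle.assign(param - lr * grad, output=param)
    for idx, (param, _) in enumerate(params_and_grads):
        # c_broadcast: updated parameter is sent back to every rank.
        dist.broadcast(param, src=idx % num_ranks)
    # Per gradient this costs one reduce plus one broadcast, the same order of
    # traffic as a single c_allreduce_sum.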
The current _dp_as_optimizer_sharding design is not final; it is only a temporary plan. When AMP and GlobalGradientClip are present, the inputs of the check_finite_and_unscale_op, update_loss_scaling_op, and sum_op they use may be pruned to empty. For AMP, we modify the logic of check_finite_and_unscale_op and update_loss_scaling_op so that they still execute correctly even with empty inputs. For GlobalGradientClip, we replace the sum_op with fill_constant(0.0), which does not affect the numerical correctness of the program. A test program for this boundary condition is available at https://gist.github.com/wangxicoding/d3b27289a545f62bec5130fc2952a542
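A rough sketch of why the sum_op to fill_constant(0.0) substitution is numerically safe; the helper name below is hypothetical, and the real change is made at the program-pass level rather than in user code:

import paddle

def local_squared_norm_sum(local_grad_squares):
    # Under optimizer sharding, a rank's input list to the sum_op can be pruned
    # to empty. Emitting fill_constant(0.0) in that case contributes nothing to
    # the global gradient-norm accumulation, so the clipping result is unchanged.
    if not local_grad_squares:
        return paddle.full(shape=[1], fill_value=0.0, dtype='float32')
    return paddle.add_n(local_grad_squares)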
_dp_as_optimizer_sharding already supports fuse_allreduce and fuse_grad_merge. Remaining TODO: support optimize_cast, which uses fp16 gradients in the forward and backward passes, reducing the number of cast ops, the storage of fp32 parameters, and the communication volume of the parameter broadcast.

Precision test
Ernie3.0, base model, single machine with 8 GPUs
baseline=2mp+2pp+2dp, optimizer_sharding=2mp+2pp+2opt_sharding
Memory test
Ernie3.0, single machine with 8 GPUs