
[Dygraph]Integration sharding stage2 function #38151

Merged: 2 commits merged from integration_stage2_function into PaddlePaddle:develop on Dec 19, 2021

Conversation


@Baibaifan (Contributor) commented on Dec 15, 2021:

PR types

Performance optimization

PR changes

Others

Describe

Integrate the sharding stage2 function:
1. Support group = None
2. Support param_groups for the optimizer

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

fleet.init(is_collective=True)
group = paddle.distributed.new_group([0, 1])

# wrap model & optimizer (model_class and optimizer are user-defined placeholders)
model = model_class(...)
oss_optimizer = ShardingOptimizerStage2(params=model.parameters(), optim=optimizer, group=group)
model = ShardingStage2(model, oss_optimizer, group=group)

# use optimizer as normal
img, label = data
label.stop_gradient = True
img.stop_gradient = True
out = model(img)

loss = paddle.nn.functional.cross_entropy(input=out, label=label)
loss.backward()
oss_optimizer.step()
oss_optimizer.clear_grad()
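
The snippet above uses an explicit two-rank group and a plain parameter list. Below is a hedged sketch of the two additions listed in the description, reusing the imports and the unwrapped model from the example; the per-group options and the behavior when group is omitted are assumptions based on this PR's description, not confirmed documentation.

# Sketch only: `model` is the unwrapped Layer from the example above.

# param_groups for the inner optimizer: Paddle optimizers accept a list of
# dicts with a "params" key plus per-group options such as weight_decay.
decay_params, no_decay_params = [], []
for name, param in model.named_parameters():
    (no_decay_params if "bias" in name else decay_params).append(param)

inner_optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=[
        {"params": decay_params, "weight_decay": 0.01},
        {"params": no_decay_params, "weight_decay": 0.0},
    ])

# group=None: leave the group argument unset and let the stage2 wrappers fall
# back to the global group created by fleet.init(is_collective=True).
oss_optimizer = ShardingOptimizerStage2(params=model.parameters(), optim=inner_optimizer)
model = ShardingStage2(model, oss_optimizer)

The forward/backward/step/clear_grad part of the training loop stays the same as in the example above.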


@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@Baibaifan force-pushed the integration_stage2_function branch 3 times, most recently from b1bf3cc to 7d5ad2e on December 15, 2021 06:35
@Baibaifan force-pushed the integration_stage2_function branch 3 times, most recently from 04d6d9f to 5d6cc91 on December 17, 2021 06:14
self._rank_buffer_size = {} # {dtype: {rank: numel+alignment}}
self._param2align = {} # {param.name: align}

# Default information
self._optim_defaults = kw
self._optim = optim
self._ori_parameter_list = copy.deepcopy(self._optim._parameter_list)
self._ori_param_groups = copy.deepcopy(self._optim._param_groups)
A reviewer (Member) commented on the snippet above:

deepcopy increases memory usage.

@Baibaifan (Contributor, Author) replied:

Fixed; the deepcopy calls were changed to pass references instead.
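
A minimal sketch of the change being described, with attribute names taken from the snippet above; it illustrates the idea rather than reproducing the exact patch.

# Before: deepcopy duplicates the optimizer's parameter containers and raises memory use.
#   self._ori_parameter_list = copy.deepcopy(self._optim._parameter_list)
#   self._ori_param_groups = copy.deepcopy(self._optim._param_groups)

# After (sketch): keep references to the original containers instead of copies.
self._ori_parameter_list = self._optim._parameter_list
self._ori_param_groups = self._optim._param_groups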

@@ -94,7 +94,7 @@ def __init__(self,
             filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
                    self._local_params))) > 0

-        assert group is not None, "Distributed communication group is must be gived"
+        assert group is not None, "Distributed communication group is must be given"
A reviewer (Member) commented on this diff:

Need to support the global group when group=None.

@Baibaifan (Contributor, Author) replied:

Now supported.
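
A hedged sketch of the fallback being described; _get_global_group is a hypothetical stand-in for whatever Paddle uses internally to expose the default communication group, and the Group attributes used here are assumptions.

# Sketch: fall back to the global group instead of asserting group is not None.
if group is None:
    group = _get_global_group()  # hypothetical helper, not a confirmed Paddle API
self._group = group
self.world_size = group.nranks  # assumed Group attributes
self.rank = group.rank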

@Baibaifan force-pushed the integration_stage2_function branch 3 times, most recently from 2715035 to 0f53247 on December 17, 2021 13:02
@Baibaifan changed the title from "Integration sharding stage2 function" to "[Dygraph]Integration sharding stage2 function" on Dec 18, 2021
ForFishes previously approved these changes on Dec 18, 2021.

@ForFishes (Member) left a comment:

LGTM

@ForFishes (Member) left another comment:

LGTM

@Baibaifan merged commit 327e505 into PaddlePaddle:develop on Dec 19, 2021