
Commit

fix_import_distribute_bugs
Baibaifan committed Mar 10, 2022
1 parent 9262a93 commit 7e4d84f
Showing 6 changed files with 99 additions and 20 deletions.
Changed file 1 of 6
@@ -25,10 +25,9 @@
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid import core
-import paddle.distributed as dist
 from paddle.optimizer import Optimizer
 from paddle.fluid.clip import ClipGradByGlobalNorm
-from paddle.distributed.collective import _get_global_group
+from paddle.distributed.collective import _get_global_group, new_group, broadcast, wait

 from ...utils.internal_storage import ParamStorage, GradStorage
 from ...meta_parallel.sharding.sharding_utils import Type, device_guard, ShardingClipGrad
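
The hunk above drops the module-level `paddle.distributed as dist` alias and imports the few collective helpers the optimizer actually calls straight from `paddle.distributed.collective`. A minimal sketch of the same usage pattern, assuming a process group has already been initialized (for example via `paddle.distributed.init_parallel_env()`); `sync_from_root` is an illustrative name, and the module path and `use_calc_stream` argument follow this 2022-era codebase rather than newer Paddle releases:

    # Sketch only: broadcast tensors from a root rank using the directly
    # imported collective helpers, mirroring the calls in the diff above.
    from paddle.distributed.collective import _get_global_group, new_group, broadcast, wait

    def sync_from_root(tensors, root_rank=0):
        # Build a group spanning all ranks, as the optimizer's default group does.
        group = new_group(_get_global_group().ranks)
        for t in tensors:
            # Broadcast on the calculation stream, then wait so later kernels
            # see the synchronized values.
            broadcast(t, src=root_rank, group=group, use_calc_stream=True)
            wait(tensor=t, group=group, use_calc_stream=True)
        return group
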
@@ -91,8 +90,8 @@ def __init__(self,
                 filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
                        self._local_params))) > 0

-        self.group = dist.new_group(_get_global_group()
-                                    .ranks) if group is None else group
+        self.group = new_group(_get_global_group()
+                               .ranks) if group is None else group

         self.world_size = self.group.nranks
         self.rank = self.group.rank
@@ -141,14 +140,14 @@ def _sync_params_and_buffers(self):
         """

         for p in self._local_params:
-            dist.broadcast(
+            broadcast(
                 p,
                 src=self._global_root_rank,
                 group=self.group,
                 use_calc_stream=True)

             # Multi stream operation will be supported later
-            dist.wait(tensor=p, group=self.group, use_calc_stream=True)
+            wait(tensor=p, group=self.group, use_calc_stream=True)

     def _generate_master_params(self, trainable_params):
         if self.offload:
@@ -385,6 +384,12 @@ def minimize(self):
         raise RuntimeError(
             "optimizer.minimize() not support now, please use optimizer.step()")

+    def set_state_dict(self, state_dict):
+        self._optim.set_state_dict(state_dict)
+
+    def state_dict(self):
+        return self._optim.state_dict()
+
     def _clear_cache(self):
         self.__segment_params.clear()
         self._dtype_rank_params.clear()
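
The two new methods simply forward to the wrapped inner optimizer (`self._optim`), so checkpointing the sharded optimizer looks the same as checkpointing a plain `paddle.optimizer.Optimizer`. A hedged sketch of the resulting save/load flow; `sharded_opt` and `ckpt_path` are placeholder names, not identifiers from this commit:

    import paddle

    def save_optimizer_state(sharded_opt, ckpt_path):
        # state_dict() now returns the inner optimizer's state (LR, moments, ...).
        paddle.save(sharded_opt.state_dict(), ckpt_path)

    def load_optimizer_state(sharded_opt, ckpt_path):
        # set_state_dict() forwards the loaded dict back to the inner optimizer.
        sharded_opt.set_state_dict(paddle.load(ckpt_path))
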
@@ -399,14 +404,14 @@ def _broadcast_params(self):
         # Exchange all the shards with the other ranks
         for dtype_per_rank in self.param_storages.values():
             for dst_rank, internal_storage in dtype_per_rank.items():
-                dist.broadcast(
+                broadcast(
                     tensor=internal_storage.buffer,
                     src=self.group.ranks[dst_rank],
                     group=self.group,
                     use_calc_stream=True)

                 # Multi stream operation will be supported later
-                dist.wait(
+                wait(
                     tensor=internal_storage.buffer,
                     group=self.group,
                     use_calc_stream=True)
Changed file 2 of 6
@@ -28,7 +28,7 @@

 import paddle
 from paddle import nn
-import paddle.distributed as dist
+from paddle.distributed import collective as dist
 from paddle.distributed.collective import _get_global_group

 from ...utils.internal_storage import GradStorage
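
Unlike the optimizer file above, this module keeps its existing `dist.*` call sites and instead rebinds the `dist` alias to the collective submodule. A short sketch of the idea; the commented call sites are illustrative, not copied from this file:

    # Before: the alias pointed at the paddle.distributed package.
    #     import paddle.distributed as dist
    # After: the same alias points at paddle.distributed.collective, so existing
    # call sites keep working but resolve against the collective module directly.
    from paddle.distributed import collective as dist

    # Illustrative call sites left untouched elsewhere in the file:
    #     dist.broadcast(tensor=buf, src=root, group=group, use_calc_stream=True)
    #     dist.wait(tensor=buf, group=group, use_calc_stream=True)
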
@@ -158,6 +158,17 @@ def forward(self, *inputs, **kwargs):

         return fw

+    def set_state_dict(self, state_dict, use_structured_name=True):
+        self._layer.set_state_dict(
+            state_dict, use_structured_name=use_structured_name)
+
+    def state_dict(self,
+                   destination=None,
+                   include_sublayers=True,
+                   structured_name_prefix=""):
+        return self._layer.state_dict(
+            destination=None, include_sublayers=True, structured_name_prefix="")
+
     def _clear_gradients(self):
         """
         Set zero to the gradient of the optimizer's current rank trainable parameters.
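
Because the wrapper delegates `state_dict`/`set_state_dict` to the inner `_layer`, a checkpoint saved through the sharded wrapper uses the plain layer's keys and can be restored into either the wrapper or the original model. A minimal sketch, with `sharded_model` and `plain_model` as placeholder names:

    import paddle

    def roundtrip_checkpoint(sharded_model, plain_model, path="model.pdparams"):
        # Keys come from the wrapped layer, not from the sharding wrapper itself.
        paddle.save(sharded_model.state_dict(), path)
        loaded = paddle.load(path)
        # The same dict restores into the wrapper or the unwrapped layer.
        sharded_model.set_state_dict(loaded)
        plain_model.set_state_dict(loaded)
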
Changed file 3 of 6
@@ -20,17 +20,16 @@
 import functools
 import numpy as np
 from itertools import chain
-from functools import reduce
 from types import MethodType
 from collections import deque, OrderedDict

 import paddle
 from paddle import nn
 from paddle.autograd import PyLayer
 import paddle.fluid.core as core
-import paddle.distributed as dist
 from paddle.fluid.framework import ParamBase
 from paddle.fluid.clip import ClipGradByGlobalNorm
+from paddle.distributed import collective as dist
 from paddle.distributed.collective import _get_global_group

 from .sharding_utils import Type, ShardingClipGrad, device_guard
@@ -249,6 +248,17 @@ def forward(self, *inputs, **kwargs):

         return fw

+    def set_state_dict(self, state_dict, use_structured_name=True):
+        self._layer.set_state_dict(
+            state_dict, use_structured_name=use_structured_name)
+
+    def state_dict(self,
+                   destination=None,
+                   include_sublayers=True,
+                   structured_name_prefix=""):
+        return self._layer.state_dict(
+            destination=None, include_sublayers=True, structured_name_prefix="")
+
     def _handle_unslice_params(self):
         buffer_size = dict()
         buffer_size[Type.fp32.value] = 0
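
One caveat in the added `state_dict` (in both this file and the stage 2 wrapper above): it passes literal defaults to the inner layer rather than the `destination`, `include_sublayers`, and `structured_name_prefix` arguments it receives, so non-default arguments are silently ignored. A sketch of a variant that forwards them, offered as an assumption about the intended behaviour rather than as part of this commit; `_DelegatingWrapper` is an illustrative stand-in class:

    from paddle import nn

    class _DelegatingWrapper(nn.Layer):
        def __init__(self, layer):
            super().__init__()
            self._layer = layer

        def state_dict(self,
                       destination=None,
                       include_sublayers=True,
                       structured_name_prefix=""):
            # Forward the caller's arguments instead of re-passing defaults.
            return self._layer.state_dict(
                destination=destination,
                include_sublayers=include_sublayers,
                structured_name_prefix=structured_name_prefix)
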
@@ -523,7 +533,7 @@ def _register_backward_hooks(self):

     def _get_allreduce_fn(self, param):
         @paddle.autograd.no_grad()
-        def reduce(*_):
+        def allreduce_(*_):
             if param.name in self._task_flow.full_grad.keys():
                 full_grad = self._task_flow.full_grad[param.name]
                 # Only support sync allreduce current rank's layer now
@@ -573,7 +583,7 @@ def reduce(*_):
             if self._offload:
                 param.fw_storage = _device2cpu(param.fw_storage, True)

-        return reduce
+        return allreduce_

     def _param2align(self, param):
         # CUDA alignment 256 bytes
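
The renamed closure is the per-parameter gradient hook that `_get_allreduce_fn` returns and that the stage 3 wrapper registers on each parameter's backward pass. A generic sketch of that hook pattern; `make_allreduce_hook` and `group` are illustrative names, and the all-reduce/wait signatures follow the 2022-era collective module this file uses:

    import paddle
    from paddle.distributed import collective as dist

    def make_allreduce_hook(param, group):
        # Build a no-grad closure bound to one parameter; the framework calls it
        # once that parameter's gradient has been produced.
        @paddle.autograd.no_grad()
        def allreduce_(*_):
            if param.grad is not None:
                # Sum the gradient across ranks in the sharding group.
                dist.all_reduce(param.grad, group=group, use_calc_stream=True)
                dist.wait(tensor=param.grad, group=group, use_calc_stream=True)

        return allreduce_
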
Changed file 4 of 6
@@ -21,7 +21,6 @@
 from types import MethodType

 import paddle
-import paddle.distributed as dist
 from paddle import _C_ops
 from paddle.fluid import core
 from paddle.fluid import layers
Changed file 5 of 6: python/paddle/fluid/tests/unittests/dygraph_sharding_stage2.py (32 changes: 30 additions & 2 deletions)
@@ -14,8 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+import shutil
 import numpy as np
 import argparse
+import tempfile
 import ast
 import time
 import paddle
@@ -88,7 +91,8 @@ def train_mlp(model,
               batch_size=100,
               use_pure_fp16=False,
               accumulate_grad=False,
-              opt_group=False):
+              opt_group=False,
+              save_model=False):
     if sharding_stage == "dp":
         hcg = fleet.get_hybrid_communicate_group()
         group = hcg.get_check_parallel_group()
@@ -147,6 +151,9 @@ def train_mlp(model,
         if accumulate_grad:
             optimizer.step()
             optimizer.clear_grad()
+
+    if save_model:
+        return model, optimizer
     return model.parameters()


@@ -158,11 +165,13 @@ def test_dp_stage2():
     mlp3 = MLP()
     mlp4 = MLP()
     mlp5 = MLP()
+    mlp6 = MLP()
     mlp1.set_state_dict(state_dict)
     mlp2.set_state_dict(state_dict)
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)
     mlp5.set_state_dict(state_dict)
+    mlp6.set_state_dict(state_dict)

     # DP VS stage2
     dp_params = train_mlp(
@@ -186,10 +195,29 @@ def test_dp_stage2():

     # stage2 param list VS param group
     stage2_params = train_mlp(
-        mlp2, sharding_stage=2, use_pure_fp16=False, opt_group=True)
+        mlp5, sharding_stage=2, use_pure_fp16=False, opt_group=True)
     for i in range(len(dp_params)):
         np.testing.assert_allclose(
             dp_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)

+    # save/load model
+    output_dir = tempfile.mkdtemp()
+    model_file = os.path.join(output_dir, "model.pdmodel")
+    optimizer_file = os.path.join(output_dir, "model.pdopt")
+    model_stage2, optimizer_stage2 = train_mlp(
+        mlp6,
+        sharding_stage=2,
+        use_pure_fp16=False,
+        opt_group=False,
+        save_model=True)
+    paddle.save(model_stage2.state_dict(), model_file)
+    paddle.save(optimizer_stage2.state_dict(), optimizer_file)
+    m_state_dict = paddle.load(model_file)
+    opt_state_dict = paddle.load(optimizer_file)
+    model_stage2.set_state_dict(m_state_dict)
+    optimizer_stage2.set_state_dict(opt_state_dict)
+    shutil.rmtree(output_dir)
+
     return
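
The new save/load branch restores both state dicts but does not compare tensors after reloading. A small sketch of an extra check one could bolt on, offered as an assumption rather than part of this commit; `assert_roundtrip_equal` is a hypothetical helper:

    import numpy as np

    def assert_roundtrip_equal(model, loaded_state_dict, rtol=1e-6):
        # Compare every current parameter against its reloaded checkpoint entry.
        for name, param in model.state_dict().items():
            np.testing.assert_allclose(
                param.numpy(), loaded_state_dict[name].numpy(), rtol=rtol)
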


Changed file 6 of 6: python/paddle/fluid/tests/unittests/dygraph_sharding_stage3.py (34 changes: 30 additions & 4 deletions)
@@ -14,6 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+import shutil
+import tempfile
 import numpy as np
 import argparse
 import ast
@@ -84,7 +87,8 @@ def train_mlp(model,
               batch_size=100,
               opt_group=False,
               sync_comm=False,
-              test_minimize=False):
+              test_minimize=False,
+              save_model=False):
     group = paddle.distributed.new_group([0, 1])
     if opt_group:
         optimizer = optimizer_setting(
@@ -162,12 +166,15 @@ def train_mlp(model,
             optimizer.clear_grad()
         if sharding_stage == 3:
             model.get_all_parameters()
+
+    if save_model:
+        return model, optimizer
     return model.parameters()


 def test_stage2_stage3():
-    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9 = MLP(), MLP(
-    ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
+    mlp, mlp1, mlp2, mlp3, mlp4, mlp5, mlp6, mlp7, mlp8, mlp9, mlp10 = MLP(
+    ), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP(), MLP()
     state_dict = mlp.state_dict()
     mlp1.set_state_dict(state_dict)
     mlp2.set_state_dict(state_dict)
@@ -178,6 +185,7 @@ def test_stage2_stage3():
     mlp7.set_state_dict(state_dict)
     mlp8.set_state_dict(state_dict)
     mlp9.set_state_dict(state_dict)
+    mlp10.set_state_dict(state_dict)

     # fp32
     stage2_params = train_mlp(
@@ -238,9 +246,27 @@ def test_stage2_stage3():
         np.testing.assert_allclose(
             stage3_params[i].numpy(), stage3_params_re[i].numpy(), rtol=1e-6)

+    # save/load model
+    output_dir = tempfile.mkdtemp()
+    model_file = os.path.join(output_dir, "model.pdmodel")
+    optimizer_file = os.path.join(output_dir, "model.pdopt")
+    model_stage3, optimizer_stage3 = train_mlp(
+        mlp9,
+        sharding_stage=3,
+        use_pure_fp16=False,
+        opt_group=False,
+        save_model=True)
+    paddle.save(model_stage3.state_dict(), model_file)
+    paddle.save(optimizer_stage3.state_dict(), optimizer_file)
+    m_state_dict = paddle.load(model_file)
+    opt_state_dict = paddle.load(optimizer_file)
+    model_stage3.set_state_dict(m_state_dict)
+    optimizer_stage3.set_state_dict(opt_state_dict)
+    shutil.rmtree(output_dir)
+
     # check optimizer.minimize() error
     train_mlp(
-        mlp9,
+        mlp10,
         sharding_stage=3,
         use_pure_fp16=False,
         opt_group=False,
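
Both test scripts assume two ranks. A hedged sketch of driving the stage 3 entry point with Paddle's spawn helper; in CI these scripts are launched by the unit-test harness, so this manual driver, the module import path, and `nprocs=2` (two visible GPUs) are all assumptions for illustration:

    import paddle.distributed as dist
    from dygraph_sharding_stage3 import test_stage2_stage3

    if __name__ == "__main__":
        dist.spawn(test_stage2_stage3, nprocs=2)
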

1 comment on commit 7e4d84f

@paddle-bot-old

Congratulations! Your pull request passed all required CI checks. You can ask the reviewer(s) to approve and merge. 🎉
