[HybridParallel] Support save/load for PipeLineParallel #34768

Merged
merged 8 commits on Aug 11, 2021
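In short, this change adds two methods to `PipelineLayer`: `save_state_dict(path)`, which has each pipeline stage (only its data-parallel rank 0 replica) write the parameters of the layers it owns to per-layer, per-tensor-rank `.pdparams` files, and `set_state_dir(path)`, which loads those shards back and re-synchronizes shared weights. The sketch below is modeled on the `hybrid_parallel_pp_save_load.py` test added in this PR; the `ModelPipe` import, the strategy values, and the checkpoint path come from that test or are illustrative, and the script must be launched on two GPUs with Paddle's distributed launcher.

```python
import os

import paddle
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_transformer import ModelPipe  # model defined in the test below

# Sketch only: assumes a 2-GPU launch with a pure pipeline-parallel strategy,
# mirroring the setUp() of the new unit test.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
strategy.pipeline_configs = {"accumulate_steps": 4, "micro_batch_size": 2}
fleet.init(is_collective=True, strategy=strategy)

hcg = fleet.get_hybrid_communicate_group()
model = ModelPipe(hcg.topology())
optimizer = paddle.optimizer.SGD(learning_rate=0.001,
                                 parameters=model.parameters())
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)

ckpt_dir = "./pp_checkpoint"  # illustrative path

# Save: each pipeline stage writes only the layers it holds.
model._layers.save_state_dict(ckpt_dir)
paddle.save(optimizer.state_dict(),
            os.path.join(ckpt_dir, "model_state.pdopt"))

# Load: each stage reads back its own shards and re-syncs shared weights.
model._layers.set_state_dir(ckpt_dir)
optimizer.set_state_dict(
    paddle.load(os.path.join(ckpt_dir, "model_state.pdopt")))
```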
@@ -11,12 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import re
import glob
import os
import numpy as np
import random
from functools import partial

import paddle
from paddle.fluid.dygraph.layers import Layer
from ...utils.log_util import logger, layer_to_str

__all__ = []

@@ -310,3 +316,48 @@ def forward(self, input):
        for layer in self.run_function:
            input = layer(input)
        return input

    def save_state_dict(self, path):
        if self._topo.get_coord(self.global_rank).data != 0:
            return

        def _offset_dirname(ckpt_dir, local_layer_idx):
            idx = local_layer_idx + self._start_pos
            model_rank = self._topo.get_coord(self.global_rank).model
            rank_message = "-tensor_" + "{:0>2d}".format(model_rank)
            layer_save_path = os.path.join(ckpt_dir,
                                           'layer_{:0>2d}'.format(idx))
            layer_save_path = layer_save_path + rank_message + '-model_states.pdparams'
            return layer_save_path

        os.makedirs(path, exist_ok=True)
        for idx, layer in enumerate(self.run_function):
            model_save_path = _offset_dirname(path, idx)
            if not hasattr(layer, 'state_dict'):
                continue
            paddle.save(layer.state_dict(), model_save_path)

        logger.info("save model state successfully...")

    def set_state_dir(self, path):
        assert os.path.exists(
            path), "{} not found, please check the path".format(path)

        for idx, layer in enumerate(self.run_function):
            if not hasattr(layer, 'set_state_dict'):
                continue
            layer_idx = idx + self._start_pos
            layer_save_path = os.path.join(path,
                                           'layer_{0:0>2d}'.format(layer_idx))
            model_files = glob.glob(layer_save_path + "*model_states.pdparams")
            model_files.sort()
            mp_rank = self._topo.get_coord(self.global_rank).model
            mp_world_size = self._topo.get_dim('model')
            num_files = len(model_files)

            load_param_path = model_files[mp_rank * num_files // mp_world_size]
            model_state_dict = paddle.load(load_param_path)
            layer.set_state_dict(model_state_dict)

        self._synchronize_shared_weights()
        logger.info("load model state successfully...")
python/paddle/fluid/tests/unittests/hybrid_parallel_pp_save_load.py (114 additions, 0 deletions)
@@ -0,0 +1,114 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import unittest
import paddle
import numpy as np
import random
import os
import shutil
import tempfile
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from hybrid_parallel_pp_transformer import ModelPipe, set_random_seed

batch_size = 8
length = 8
micro_batch_size = 2
vocab_size = 128


class TestDistPPSaveLoadTraning(unittest.TestCase):
    def setUp(self):
        strategy = fleet.DistributedStrategy()
        self.model_parallel_size = 1
        self.data_parallel_size = 1
        self.pipeline_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": self.data_parallel_size,
            "mp_degree": self.model_parallel_size,
            "pp_degree": self.pipeline_parallel_size,
        }
        strategy.pipeline_configs = {
            "accumulate_steps": batch_size // micro_batch_size,
            "micro_batch_size": micro_batch_size
        }
        fleet.init(is_collective=True, strategy=strategy)

    def test_pp_model(self):
        hcg = fleet.get_hybrid_communicate_group()
        word_size = hcg.get_model_parallel_world_size()
        dp_id = hcg.get_data_parallel_rank()
        pp_id = hcg.get_stage_id()
        rank_id = dist.get_rank()
        topology = hcg.topology()
        set_random_seed(1024, dp_id, rank_id)

        model = ModelPipe(topology)
        scheduler = paddle.optimizer.lr.PiecewiseDecay(
            boundaries=[2], values=[0.001, 0.002], verbose=True)
        optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
                                         parameters=model.parameters())

        model = fleet.distributed_model(model)
        optimizer = fleet.distributed_optimizer(optimizer)
        output_dir = tempfile.mkdtemp()

        # warmup step
        for step_id in range(2):
            x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
            x = paddle.to_tensor(x_data)
            x.stop_gradient = True
            loss = model.train_batch([x, x], optimizer, scheduler)

        model._layers.save_state_dict(output_dir)
        paddle.save(optimizer.state_dict(),
                    os.path.join(output_dir, "model_state.pdopt"))

        # construct data
        test_steps = 5
        np_data = np.random.randint(
            0, vocab_size, size=[test_steps, batch_size, length])

        origin_loss = []
        for step_id in range(5):
            x_data = np_data[step_id, :]
            x = paddle.to_tensor(x_data)
            x.stop_gradient = True
            loss = model.train_batch([x, x], optimizer, scheduler)
            origin_loss.append(loss.numpy())

        # test step
        model._layers.set_state_dir(output_dir)
        opt_dict = paddle.load(os.path.join(output_dir, "model_state.pdopt"))
        optimizer.set_state_dict(opt_dict)

        for step_id in range(5):
            x_data = np_data[step_id, :]
            x = paddle.to_tensor(x_data)
            x.stop_gradient = True
            loss = model.train_batch([x, x], optimizer, scheduler)
            print("origin loss: ", origin_loss[step_id], "current loss: ",
                  loss.numpy())
            np.testing.assert_allclose(loss.numpy(), origin_loss[step_id])

        # finally, remove the model/optimizer path
        shutil.rmtree(output_dir)


if __name__ == "__main__":
    unittest.main()
@@ -86,7 +86,8 @@ def forward(self, x, mask):
        product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5)

        weights = F.softmax(product + mask)
        # TODO(shenliang03): save/load for PipelineParallel cannot support dropout yet, so it is disabled temporarily.
        # weights = F.dropout(weights, 0.2)
        tgt = layers.matmul(weights, v)
        residual = tgt
        tgt = self.norm1(tgt)
@@ -36,6 +36,9 @@ def test_pipeline_parallel(self):
    def test_hybrid_parallel_transformer(self):
        self.run_mnist_2gpu('hybrid_parallel_pp_transformer.py')

    def test_hybrid_parallel_save_load(self):
        self.run_mnist_2gpu('hybrid_parallel_pp_save_load.py')


if __name__ == "__main__":
    unittest.main()