Sharding stage2 overlap #46495

Merged: 1 commit, Sep 28, 2022
File: group_sharded_optimizer_stage2.py (GroupShardedOptimizerStage2)
@@ -86,6 +86,11 @@ def __init__(self,
# Default information
self._optim = optim

# sharding stage 2 comm overlap flag
self._comm_overlap = False
# record the last reduce task used for comm overlap in sharding stage 2
self._comm_task = None

assert hasattr(self._optim, "_master_weights"
), "Must use optimizer with _master_weights attribute"

@@ -157,6 +162,17 @@ def _sync_params_and_buffers(self):
group=self._group,
sync_op=True)

def _update_task(self, task):
if self._comm_overlap:
assert task is not None
# Only keep track of the last reduce task.
# Since all tasks are enqueued on the same stream, only the last one needs to be waited on;
# once it finishes, all earlier reduce tasks are guaranteed to have finished as well.
self._comm_task = task

def _set_comm_overlap(self, comm_overlap):
self._comm_overlap = comm_overlap

def _generate_master_params(self, trainable_params):
if self.offload:
for param in trainable_params:
@@ -364,7 +380,8 @@ def step(self):
"""
A wrapper for Optimizer's step function to finish the update operation of the optimizer.
"""

# This method is not called directly via opt.step().
# _redefine_opt_step() in class GroupShardedStage2 wraps this function.
if self.offload:
params_list = [self.offload_params.buffer]

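The two hunks above are the optimizer-side half of the overlap: GroupShardedStage2 hands every non-blocking reduce task to _update_task(), and the wrapped step (see the _opt_step hunk in group_sharded_stage2.py below) waits on the stored task once before the gradients are consumed. Below is a minimal sketch of that pattern; the LastTaskTracker and reduce_grad_async names are hypothetical and not part of the PR, only collective.reduce and its sync_op argument mirror the diff.

from paddle.distributed import collective


class LastTaskTracker:
    # Hypothetical helper, not Paddle API: keeps only the most recent async task.

    def __init__(self):
        self._task = None

    def update(self, task):
        # All reduces are enqueued on the same communication stream, so
        # remembering the last task is enough: waiting on it implies every
        # earlier reduce has already finished.
        assert task is not None
        self._task = task

    def wait(self):
        # Called once, right before the optimizer step consumes the gradients.
        if self._task is not None:
            self._task.wait()
            self._task = None


def reduce_grad_async(tracker, grad, dst_rank, group):
    # sync_op=False makes the reduce non-blocking and returns a task handle.
    tracker.update(
        collective.reduce(tensor=grad,
                          dst=group.ranks[dst_rank],
                          group=group,
                          sync_op=False))

With overlap disabled, sync_op=not self._comm_overlap keeps the reduce blocking, so the default path behaves as before.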
File: group_sharded_stage2.py (GroupShardedStage2)
@@ -100,6 +100,9 @@ def __init__(
for optim in self._sharding_optimizers:
self._all_params.extend(list(optim.local_params))

# sharding stage 2 comm overlap flag
self._comm_overlap = False

self._trainable_params = []
self._grad_reduced = []
self._trainable_param2rank = {}
@@ -306,6 +309,18 @@ def _clear_counters(self):
for grad_storage in self._grad_storage_list:
grad_storage.reset_checked_in()

def _set_comm_overlap(self, comm_overlap):
# Hacky way to avoid adding an extra parameter to the `group_sharded_parallel` function.
# User should use this like:
# model, optimizer, scaler = group_sharded_parallel(...)
# model._set_comm_overlap(True)
self._comm_overlap = comm_overlap
if self._comm_overlap:
assert len(
self._sharding_optimizers
) == 1, "Only support comm overlap strategy for single optimizer"
self._sharding_optimizers[0]._set_comm_overlap(comm_overlap)

def _get_reduce_fn(self, index, param, dst_rank):
"""
There are two ways to reduce gradient.
@@ -337,11 +352,12 @@ def cleanup():
del tmp_grad
param.clear_gradient(False)

# Synchronize the reduce parameter gradient
collective.reduce(tensor=param.grad,
dst=self._group.ranks[dst_rank],
group=self._group)
# TODO (Baibaifan) Asynchronous the reduce parameter gradient
# Reduce the parameter gradient, asynchronously when comm overlap is enabled
self._sharding_optimizers[0]._update_task(
collective.reduce(tensor=param.grad,
dst=self._group.ranks[dst_rank],
group=self._group,
sync_op=not self._comm_overlap))

# Clear the task flow and trigger callback to clear the redundant gradient
# self._clear_task_flow()
@@ -385,12 +401,13 @@ def cleanup():

# Reduce the bucket
grad_storage.sent = True
# Synchronize the reduce parameter gradient
collective.reduce(
tensor=grad_storage.buffer,
dst=self._group.ranks[grad_storage.destination],
group=self._group)
# TODO (Baibaifan) Asynchronous the reduce parameter gradient
# Reduce the bucketed parameter gradients, asynchronously when comm overlap is enabled
self._sharding_optimizers[0]._update_task(
collective.reduce(
tensor=grad_storage.buffer,
dst=self._group.ranks[grad_storage.destination],
group=self._group,
sync_op=not self._comm_overlap))

cleanup()

@@ -528,6 +545,10 @@ def _redefine_opt_step(self):
opt_step = opt.step

def _opt_step(self):
if self._comm_overlap:
# Wait for the last reduce task. This wait must happen before the grad scale function.
assert self._comm_task is not None
self._comm_task.wait()
grad_func()
opt_step()

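As the _set_comm_overlap comment above notes, the flag is flipped on the wrapped model after construction rather than through a new group_sharded_parallel argument. A condensed usage sketch based on the new unit test that follows; the two-rank group, NCCL backend, and 2**21 buffer size are the test's choices, not requirements, and paddle.distributed.init_parallel_env() is assumed to have run already on two ranks.

import paddle
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import GroupShardedOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import GroupShardedStage2


def wrap_stage2_with_overlap(model, base_optimizer):
    # Shard optimizer state and gradients across the two ranks of the group.
    group = paddle.distributed.new_group([0, 1], backend="nccl")
    optimizer = GroupShardedOptimizerStage2(params=base_optimizer._parameter_list,
                                            optim=base_optimizer,
                                            group=group)
    model = GroupShardedStage2(model,
                               optimizer,
                               group=group,
                               buffer_max_size=2**21)
    # Opt in to comm overlap: gradient reduces become non-blocking and the
    # last task is waited on inside the wrapped optimizer step.
    model._set_comm_overlap(True)
    return model, optimizer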
New file: dygraph_group_sharded_stage2_comm_overlap.py
@@ -0,0 +1,241 @@
# -*- coding: UTF-8 -*-

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import numpy as np
import argparse
import tempfile
import ast
import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.fluid.framework import _test_eager_guard

from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import GroupShardedOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import GroupShardedStage2

seed = 2022
epoch = 2
linear_size = 1000

np.random.seed(seed)
paddle.seed(seed)


class MLP(fluid.Layer):

def __init__(self, linear_size=1000, param_attr=None, bias_attr=None):
super(MLP, self).__init__()

self._linear1 = Linear(linear_size, linear_size)
self._linear2 = Linear(linear_size, linear_size)
self._linear3 = Linear(linear_size, 10)

def forward(self, inputs):
y = self._linear1(inputs)
y = self._linear2(y)
y = self._linear3(y)
return y


def reader_decorator(linear_size=1000):

def __reader__():
for _ in range(100):
img = np.random.rand(linear_size).astype('float32')
label = np.ones(1).astype('int64')
yield img, label

return __reader__


def optimizer_setting(model, use_pure_fp16, opt_group=False):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(parameters=[{
"params": model.parameters(),
}] if opt_group else model.parameters(),
learning_rate=0.001,
weight_decay=0.00001,
grad_clip=clip,
multi_precision=use_pure_fp16)

return optimizer


def train_mlp(model,
sharding_stage,
batch_size=100,
use_pure_fp16=False,
accumulate_grad=False,
opt_group=False,
save_model=False,
test_minimize=False):
if sharding_stage != "dp":
group = paddle.distributed.new_group([0, 1], backend="nccl")
if opt_group:
optimizer = optimizer_setting(model=model,
use_pure_fp16=use_pure_fp16,
opt_group=opt_group)
else:
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)

if sharding_stage == 2:
optimizer = GroupShardedOptimizerStage2(
params=optimizer._parameter_list, optim=optimizer, group=group)
model = GroupShardedStage2(model,
optimizer,
group=group,
buffer_max_size=2**21)
model._set_comm_overlap(True)
else:
model = paddle.DataParallel(model)

# check optimizer.minimize() error
if test_minimize:
try:
optimizer.minimize()
except:
print(
"====== Find sharding_stage2_optimizer.minimize() error ======")
return

train_reader = paddle.batch(reader_decorator(),
batch_size=batch_size,
drop_last=True)

train_loader = paddle.io.DataLoader.from_generator(capacity=32,
use_double_buffer=True,
iterable=True,
return_list=True,
use_multiprocess=True)
train_loader.set_sample_list_generator(train_reader)

if sharding_stage == 2:
model.to(device="gpu")

for eop in range(epoch):
model.train()

for batch_id, data in enumerate(train_loader()):
img, label = data
label.stop_gradient = True
img.stop_gradient = True

out = model(img)
loss = paddle.nn.functional.cross_entropy(input=out, label=label)

avg_loss = paddle.mean(x=loss.cast(dtype=paddle.float32))
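# The 100-sample reader yields 5 batches per epoch at batch_size=20, so the
# loss is scaled by 1/5 to keep the accumulated gradients comparable to a
# single batch_size=100 step.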
if batch_size == 20:
avg_loss = avg_loss / 5
avg_loss.backward()

if not accumulate_grad:
optimizer.step()
optimizer.clear_grad()

if accumulate_grad:
optimizer.step()
optimizer.clear_grad()

if save_model:
return model, optimizer
return model.parameters()


def test_dp_stage2():
paddle.distributed.init_parallel_env()
mlp = MLP()
state_dict = mlp.state_dict()
mlp1 = MLP()
mlp2 = MLP()
mlp3 = MLP()
mlp4 = MLP()
mlp5 = MLP()
mlp6 = MLP()
mlp7 = MLP()
mlp1.set_state_dict(state_dict)
mlp2.set_state_dict(state_dict)
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
mlp5.set_state_dict(state_dict)
mlp6.set_state_dict(state_dict)
mlp7.set_state_dict(state_dict)

# DP VS stage2
dp_params = train_mlp(mlp1,
sharding_stage="dp",
use_pure_fp16=False,
opt_group=False)
stage2_params = train_mlp(mlp2,
sharding_stage=2,
use_pure_fp16=False,
opt_group=False)
for i in range(len(dp_params)):
np.testing.assert_allclose(dp_params[i].numpy(),
stage2_params[i].numpy(),
rtol=1e-6)

# stage2 accumulate grad
stage2_params = train_mlp(mlp3, sharding_stage=2, accumulate_grad=True)
stage2_accumulate_grad = train_mlp(mlp4,
sharding_stage=2,
batch_size=20,
accumulate_grad=True)
for i in range(len(stage2_params)):
np.testing.assert_allclose(stage2_params[i].numpy(),
stage2_accumulate_grad[i].numpy(),
rtol=1e-5,
atol=1e-5)

# stage2 param list VS param group
stage2_params = train_mlp(mlp5,
sharding_stage=2,
use_pure_fp16=False,
opt_group=True)
for i in range(len(dp_params)):
np.testing.assert_allclose(dp_params[i].numpy(),
stage2_params[i].numpy(),
rtol=1e-6)

# save/load model
output_dir = tempfile.mkdtemp()
model_file = os.path.join(output_dir, "model.pdmodel")
optimizer_file = os.path.join(output_dir, "model.pdopt")
model_stage2, optimizer_stage2 = train_mlp(mlp6,
sharding_stage=2,
use_pure_fp16=False,
opt_group=False,
save_model=True)
paddle.save(model_stage2.state_dict(), model_file)
paddle.save(optimizer_stage2.state_dict(), optimizer_file)
m_state_dict = paddle.load(model_file)
opt_state_dict = paddle.load(optimizer_file)
model_stage2.set_state_dict(m_state_dict)
optimizer_stage2.set_state_dict(opt_state_dict)
shutil.rmtree(output_dir)

# check optimizer.minimize() error
train_mlp(mlp7, sharding_stage=2, test_minimize=True)
return


if __name__ == '__main__':
with _test_eager_guard():
test_dp_stage2()
File: test launcher for the dygraph sharding stage 2 tests
@@ -31,6 +31,9 @@ def test_dygraph_sharding_stage2_offload(self):
self.run_mnist_2gpu('dygraph_sharding_stage2_offload.py',
eager_mode=False)

def test_dygraph_sharding_stage2_with_comm_overlap(self):
self.run_mnist_2gpu('dygraph_group_sharded_stage2_comm_overlap.py')


if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"