[dygraph sharding stage 2] sharding broadcast overlap (PaddlePaddle#4…
FeixLiu committed Oct 17, 2022
1 parent a98a306 commit 1a286f9
Showing 3 changed files with 102 additions and 20 deletions.
@@ -24,6 +24,8 @@

import copy
import logging
import warnings

import numpy as np
from collections import OrderedDict

@@ -87,7 +89,7 @@ def __init__(self,
self._optim = optim

# sharding stage 2 comm overlap flag
self._comm_overlap = False
self._reduce_overlap = False
# record the last task used for comm overlap for sharding stage 2
self._comm_task = None

@@ -108,6 +110,17 @@ def __init__(self,
filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
self._local_params))) > 0

self._broadcast_overlap = False
self._forward_pre_hook_remove_helper = []
try:
# The fp32 params such as layer_norm_0.w_0 will be at the end of param_list.
# Have to sort the params to make sure all params are in the order they are used in forward.
self._broadcast_order_params = sorted(
self.local_params,
key=lambda x: int(x.name.split('.')[0].split('_')[-1]))
except ValueError:
self._broadcast_order_params = None

self._group = new_group(
_get_global_group().ranks) if group is None else group
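
As an aside, a minimal standalone sketch (with hypothetical parameter names) of the sort key used above: parameters are ordered by the trailing integer of their layer name so the broadcast order matches the order the layers are used in forward; a name without a trailing integer raises ValueError and triggers the fallback handled by the except branch above.

names = ["layer_norm_0.w_0", "linear_2.w_0", "linear_0.w_0", "linear_1.b_0"]
ordered = sorted(names, key=lambda n: int(n.split('.')[0].split('_')[-1]))
print(ordered)  # ['layer_norm_0.w_0', 'linear_0.w_0', 'linear_1.b_0', 'linear_2.w_0']
# A name like "embedding.w_0" would raise ValueError in the key function,
# which is why the code above falls back to _broadcast_order_params = None.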

@@ -163,15 +176,34 @@ def _sync_params_and_buffers(self):
sync_op=True)

def _update_task(self, task):
if self._comm_overlap:
if self._reduce_overlap:
assert task is not None
# Only keep track of the last reduce task.
# Since all tasks are on the same stream, we only need to wait for the last one.
# After waiting for the last reduce task, all earlier reduce tasks have already finished.
self._comm_task = task

def _set_comm_overlap(self, comm_overlap):
self._comm_overlap = comm_overlap
def _set_reduce_overlap(self, reduce_overlap):
# Enable overlapping the gradient reduces with the backward calculation.
self._reduce_overlap = reduce_overlap

def _set_broadcast_overlap(self, broadcast_overlap, layers=None):
# Enable overlapping the post-optimizer parameter broadcasts with the next batch's forward calculation.
self._broadcast_overlap = broadcast_overlap
if self._broadcast_overlap:
assert layers is not None, \
"To enable broadcast overlap forward, please pass the module to the function."
self._layers = layers
warnings.warn(
"Setting overlap broadcast means the `paddle.device.cuda.synchronize()` "
"must be called manually before calling `paddle.save()` and before and inference."
)
if self._broadcast_order_params is None:
# Params' names should follow the column_linear_32.w_0 pattern to get the best performance.
warnings.warn(
"The param name passed to the optimizer doesn't follow .+_[0-9]+\..+ patter, "
"overlap broadcast may harm the performance.")
self._broadcast_order_params = self._local_params

def _generate_master_params(self, trainable_params):
if self.offload:
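
For context, a hedged usage sketch assembled from the test change below and the warnings above (the multi-GPU launch, `group`, and `loader` setup are assumed, and the save path is illustrative): both overlaps are switched on explicitly, and `paddle.device.cuda.synchronize()` is called before saving because the parameter broadcasts may still be in flight.

optimizer = GroupShardedOptimizerStage2(
    params=optimizer._parameter_list, optim=optimizer, group=group)
model = GroupShardedStage2(model, optimizer, group=group, buffer_max_size=2**21)
model._set_reduce_overlap(True)                # overlap grad reduces with backward
optimizer._set_broadcast_overlap(True, model)  # overlap param broadcasts with the next forward

for batch in loader:
    loss = model(batch).mean()
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()

# The broadcasts are asynchronous, so synchronize before saving or running inference.
paddle.device.cuda.synchronize()
paddle.save(model.state_dict(), "model.pdparams")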
@@ -382,6 +414,12 @@ def step(self):
"""
# This method won't be called directly by opt.step()!
# The _redefine_opt_step() in class GroupShardedStage2 will wrap this function.
if self._broadcast_overlap:
# Clear the forward pre hooks in the optimizer step.
for hook_remove in self._forward_pre_hook_remove_helper:
hook_remove.remove()
self._forward_pre_hook_remove_helper = []

if self.offload:
params_list = [self.offload_params.buffer]

@@ -425,9 +463,49 @@ def _broadcast_params(self):
"""Broadcast the parameters of the current rank to each rank"""

# Exchange all the shards with the other ranks
for dtype_per_rank in self.param_storages.values():
for dst_rank, internal_storage in dtype_per_rank.items():
broadcast(tensor=internal_storage.buffer,
src=self._group.ranks[dst_rank],
group=self._group,
sync_op=True)
if self._broadcast_overlap:
self._broadcast_params_overlap_forward()
else:
for dtype_per_rank in self.param_storages.values():
for dst_rank, internal_storage in dtype_per_rank.items():
broadcast(tensor=internal_storage.buffer,
src=self._group.ranks[dst_rank],
group=self._group,
sync_op=True)

def _forward_pre_hook_function(self, tasks):
# Since the layers call the pre hook as `forward_pre_hook(self, inputs)`,
# the helper function needs x and y to receive those two arguments.
def __impl__(x, y):
for task in tasks:
# Wait for broadcast task before using the result of the broadcast.
task.wait()

return __impl__

@paddle.autograd.no_grad()
def _broadcast_params_overlap_forward(self):
# Exchange all the shards with the other ranks,
# but overlap the broadcast with next batch's calculation.
param2task = {}
for x in self._broadcast_order_params:
if x.trainable:
task = broadcast(
tensor=x,
src=self._group.ranks[self._param2rank[x.name]],
group=self._group,
sync_op=False)
assert x.name not in param2task
param2task[x.name] = task

for layer in self._layers.sublayers():
if len(layer.sublayers()) == 0:
# Register the forward pre hook for leaf layers. This gives the best performance.
tasks = []
for param in layer.parameters():
if param.trainable:
if param.name in param2task:
tasks.append(param2task[param.name])
self._forward_pre_hook_remove_helper.append(
layer.register_forward_pre_hook(
self._forward_pre_hook_function(tasks)))
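
A minimal standalone sketch of the hook pattern above, using a dummy object in place of the async broadcast task the real code waits on:

import paddle

class DummyTask:
    def wait(self):
        print("waited for this layer's async broadcast")

def make_pre_hook(tasks):
    # register_forward_pre_hook calls the hook as hook(layer, inputs),
    # so the closure takes two arguments even though it ignores them.
    def __impl__(layer, inputs):
        for task in tasks:
            task.wait()
    return __impl__

layer = paddle.nn.Linear(4, 4)
hook_remove = layer.register_forward_pre_hook(make_pre_hook([DummyTask()]))
layer(paddle.randn([2, 4]))  # the dummy wait runs before the matmul
hook_remove.remove()         # mirrors the cleanup done in step()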
@@ -101,7 +101,7 @@ def __init__(
self._all_params.extend(list(optim.local_params))

# sharding stage 2 comm overlap flag
self._comm_overlap = False
self._reduce_overlap = False

self._trainable_params = []
self._grad_reduced = []
@@ -309,17 +309,17 @@ def _clear_counters(self):
for grad_storage in self._grad_storage_list:
grad_storage.reset_checked_in()

def _set_comm_overlap(self, comm_overlap):
def _set_reduce_overlap(self, reduce_overlap):
# Hacky way to avoid adding an extra parameter to the `group_sharded_parallel` function.
# User should use this like:
# model, optimizer, scaler = group_sharded_parallel(...)
# model._set_comm_overlap(True)
self._comm_overlap = comm_overlap
if self._comm_overlap:
# model._set_reduce_overlap(True)
self._reduce_overlap = reduce_overlap
if self._reduce_overlap:
assert len(
self._sharding_optimizers
) == 1, "Only support comm overlap strategy for single optimizer"
self._sharding_optimizers[0]._set_comm_overlap(comm_overlap)
self._sharding_optimizers[0]._set_reduce_overlap(reduce_overlap)

def _get_reduce_fn(self, index, param, dst_rank):
"""
@@ -357,7 +357,7 @@ def cleanup():
collective.reduce(tensor=param.grad,
dst=self._group.ranks[dst_rank],
group=self._group,
sync_op=not self._comm_overlap))
sync_op=not self._reduce_overlap))

# Clear the task flow and trigger callback to clear the redundant gradient
# self._clear_task_flow()
@@ -407,7 +407,7 @@ def cleanup():
tensor=grad_storage.buffer,
dst=self._group.ranks[grad_storage.destination],
group=self._group,
sync_op=not self._comm_overlap))
sync_op=not self._reduce_overlap))

cleanup()

@@ -545,7 +545,7 @@ def _redefine_opt_step(self):
opt_step = opt.step

def _opt_step(self):
if self._comm_overlap:
if self._reduce_overlap:
# Wait for the last reduce task. This wait must happen before the grad scale function.
assert self._comm_task is not None
self._comm_task.wait()
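
A hedged sketch of the reduce-overlap pattern this flag controls (assumes two or more GPUs launched via `python -m paddle.distributed.launch`; in the real stage-2 code the async reduces are issued from backward hooks so they overlap with the remaining backward compute, and each grad goes to its owning rank rather than rank 0):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
model = paddle.nn.Linear(8, 8)
opt = paddle.optimizer.SGD(parameters=model.parameters())

loss = model(paddle.randn([4, 8])).mean()
loss.backward()

last_task = None
for p in model.parameters():
    # sync_op=False returns a task handle instead of blocking
    last_task = dist.reduce(p.grad, dst=0, sync_op=False)

# Only the last task needs a wait: tasks on the same comm stream finish in order.
if last_task is not None:
    last_task.wait()
opt.step()
opt.clear_grad()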
@@ -97,13 +97,15 @@ def train_mlp(model,
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)

if sharding_stage == 2:
origin_model = model
optimizer = GroupShardedOptimizerStage2(
params=optimizer._parameter_list, optim=optimizer, group=group)
model = GroupShardedStage2(model,
optimizer,
group=group,
buffer_max_size=2**21)
model._set_comm_overlap(True)
model._set_reduce_overlap(True)
optimizer._set_broadcast_overlap(True, model)
else:
model = paddle.DataParallel(model)

@@ -154,6 +156,8 @@ def train_mlp(model,
optimizer.step()
optimizer.clear_grad()

paddle.device.cuda.synchronize()

if save_model:
return model, optimizer
return model.parameters()
