[Hybrid Parallel] Support dp & mp in dygraph #32323

Merged · 7 commits · Apr 19, 2021
39 changes: 27 additions & 12 deletions python/paddle/distributed/fleet/base/fleet_base.py
@@ -27,6 +27,9 @@
from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph import parallel_helper
from . import topology as tp
from .topology import ParallelMode
from ..meta_parallel import ModelParallel
from ..meta_optimizers import HybridParallelOptimizer


def _inited_runtime_handler_(func):
@@ -219,6 +222,9 @@ def init(self, role_maker=None, is_collective=False, strategy=None):

if paddle.fluid.framework.in_dygraph_mode():
if self.worker_num() == 1:
# if worker_num is 1, we should construct the default topology & hcg
self._topology = tp.CommunicateTopology()
self._hcg = tp.HybridCommunicateGroup(self._topology)
return
if parallel_helper._is_parallel_ctx_initialized():
warnings.warn(
@@ -694,10 +700,12 @@ def distributed_optimizer(self, optimizer, strategy=None):

self._context = {}

# TODO(shenliang03): This is a temporary solution to support amp. In the case of a dynamic graph,
# the optimizer is returned directly. This problem will be fixed in the future.
if paddle.fluid.framework.in_dygraph_mode():
return optimizer
if self.worker_num() > 1:
return HybridParallelOptimizer(optimizer, self._hcg,
self._user_defined_strategy)
else:
return optimizer
return self

@dygraph_only
@@ -756,15 +764,22 @@ def forward(self, x):


"""
assert model is not None
self.model = paddle.DataParallel(
model,
comm_buffer_size=self._user_defined_strategy.fuse_grad_size_in_MB,
last_comm_buffer_size=self._user_defined_strategy.
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
find_unused_parameters)
return self.model
assert model is not None, "model should not be None"
if self.worker_num() <= 1:
return model
if self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
distributed_model = paddle.DataParallel(
model,
comm_buffer_size=self._user_defined_strategy.
fuse_grad_size_in_MB,
last_comm_buffer_size=self._user_defined_strategy.
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
find_unused_parameters)
elif self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL:
distributed_model = ModelParallel(
model, self._hcg, strategy=self._user_defined_strategy)
return distributed_model

@dygraph_only
def state_dict(self):
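The changes above route dygraph users through fleet.distributed_model / fleet.distributed_optimizer and pick paddle.DataParallel, ModelParallel, or a plain pass-through based on the hybrid communicate group. Below is a minimal usage sketch of that path (not part of the diff; it assumes the script is launched with python -m paddle.distributed.launch so worker_num() > 1, and that the default data-parallel topology applies):

import paddle
from paddle.distributed import fleet

fleet.init(is_collective=True)

model = paddle.nn.Linear(10, 10)
optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                 parameters=model.parameters())

# wrapped as paddle.DataParallel (or ModelParallel under ParallelMode.MODEL_PARALLEL)
model = fleet.distributed_model(model)
# wrapped as HybridParallelOptimizer when worker_num() > 1
optimizer = fleet.distributed_optimizer(optimizer)

loss = model(paddle.randn([4, 10])).mean()
loss.backward()
optimizer.step()
optimizer.clear_grad()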
@@ -17,6 +17,10 @@
meta_optimizer_names = list(
filter(lambda name: name.endswith("Optimizer"), dir()))

# Because HybridParallelOptimizer is a dygraph optimizer, it
# should be removed from the static-graph meta optimizer list
meta_optimizer_names.remove("HybridParallelOptimizer")


class MetaOptimizerFactory(object):
def __init__(self):
28 changes: 25 additions & 3 deletions python/paddle/distributed/fleet/base/topology.py
@@ -24,8 +24,16 @@
_HYBRID_PARALLEL_GROUP = None


class ParallelMode(object):
DATA_PARALLEL = 0
MODEL_PARALLEL = 1
PIPELINE_PARALLEL = 2


class CommunicateTopology(object):
def __init__(self, hybrid_group_names, dims):
def __init__(self,
hybrid_group_names=["data", "pipe", "model"],
dims=[1, 1, 1]):
self._parallel_names = hybrid_group_names
self._dims = dims
self.coordinate = collections.namedtuple('Coordinate',
@@ -118,15 +126,29 @@ def __init__(self, topology):

# create comm group for data parallel
self._dp_group, self._dp_comm_group = self._set_comm_group("data")
print("data parallel group", self._dp_group, file=sys.stderr)

# create comm group for model parallel
self._mp_group, self._mp_comm_group = self._set_comm_group("model")
print("data parallel group", self._mp_group, file=sys.stderr)
debug_str = "HybridParallelInfo: rank_id: %d, dp_degree: %d, " \
"mp_degree: %d, pp_degree: %d\n" % (self.global_rank, self._dp_degree,
self._mp_degree,self._pp_degree)
debug_str += "dp_group: %s, mp_group: %s" % (self._dp_group,
self._mp_group)
print(debug_str, file=sys.stderr)

global _HYBRID_PARALLEL_GROUP
_HYBRID_PARALLEL_GROUP = self

def get_parallel_mode(self):
# there are three modes: DataParallel / ModelParallel / PipelineParallel
if self._mp_degree == 1 and self._pp_degree == 1:
return ParallelMode.DATA_PARALLEL
elif self._mp_degree > 1 and self._pp_degree == 1:
return ParallelMode.MODEL_PARALLEL
elif self._pp_degree > 1:
return ParallelMode.PIPELINE_PARALLEL

def _check_vaild_topo(self):
return self._dp_degree * self._mp_degree * self._pp_degree == self.nranks

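For reference, the mode selection in get_parallel_mode boils down to the dp/mp/pp degrees read from the "data"/"model"/"pipe" axes of the topology, and _check_vaild_topo requires dp * mp * pp to equal the total rank count. A standalone sketch of that decision (illustration only, not part of the diff):

from paddle.distributed.fleet.base.topology import ParallelMode

def infer_parallel_mode(dp_degree, mp_degree, pp_degree):
    # mirrors HybridCommunicateGroup.get_parallel_mode above
    if mp_degree == 1 and pp_degree == 1:
        return ParallelMode.DATA_PARALLEL
    if mp_degree > 1 and pp_degree == 1:
        return ParallelMode.MODEL_PARALLEL
    return ParallelMode.PIPELINE_PARALLEL

# e.g. 8 ranks split as dp=2, mp=4, pp=1 selects model parallelism
assert infer_parallel_mode(2, 4, 1) == ParallelMode.MODEL_PARALLEL
assert infer_parallel_mode(8, 1, 1) == ParallelMode.DATA_PARALLEL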
@@ -25,3 +25,4 @@
from .lamb_optimizer import LambOptimizer
from .fp16_allreduce_optimizer import FP16AllReduceOptimizer
from .sharding_optimizer import ShardingOptimizer
from .dygraph_optimizer import HybridParallelOptimizer
@@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from .hybrid_parallel_optimizer import HybridParallelOptimizer
@@ -0,0 +1,58 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.optimizer import Optimizer
from ...utils.hybrid_parallel_util import fused_allreduce_gradients
from ...base.topology import ParallelMode
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid import framework
from paddle.fluid.framework import Variable


class HybridParallelOptimizer:
def __init__(self, optimizer, hcg, strategy):
self._inner_opt = optimizer
self._strategy = strategy
self._hcg = hcg
self._is_mp = (
self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)

@imperative_base.no_grad
@framework.dygraph_only
def step(self):
if self._is_mp and self._need_dp:
fused_allreduce_gradients(
list(self._inner_opt._parameter_list), self._hcg)
self._inner_opt.step()

@imperative_base.no_grad
def minimize(self,
loss,
startup_program=None,
parameters=None,
no_grad_set=None):
assert isinstance(loss, Variable), "The loss should be a Tensor."

parameter_list = parameters if parameters \
else self._parameter_list

if self._is_mp and self._need_dp:
fused_allreduce_gradients(list(parameter_list), self._hcg)

return self._inner_opt.minimize(loss, startup_program, parameters,
no_grad_set)

def __getattr__(self, item):
return getattr(self._inner_opt, item)
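Because of the __getattr__ fallback above, anything not defined on the wrapper resolves against the inner optimizer, so existing training loops keep working unchanged. A small sketch of that delegation (illustration only; the Wrapper class here is hypothetical, HybridParallelOptimizer behaves the same way):

import paddle

class Wrapper:
    def __init__(self, opt):
        self._inner_opt = opt

    def __getattr__(self, item):
        # only invoked when normal attribute lookup fails,
        # so _inner_opt itself never recurses here
        return getattr(self._inner_opt, item)

linear = paddle.nn.Linear(4, 4)
opt = Wrapper(paddle.optimizer.Adam(learning_rate=1e-3,
                                    parameters=linear.parameters()))
print(opt.get_lr())  # 0.001, delegated to the wrapped Adam
opt.clear_grad()     # also delegated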
1 change: 1 addition & 0 deletions python/paddle/distributed/fleet/meta_parallel/__init__.py
@@ -13,3 +13,4 @@
# limitations under the License.

from .mp_utils import *
from .model_parallel import ModelParallel
@@ -0,0 +1,43 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.dygraph.layers import Layer
import logging


class MetaParallelBase(Layer):
def __init__(self, layers, hcg, strategy):
super(MetaParallelBase,
self).__init__(layers.full_name() + "_meta_parallel_base")
self._layers = layers
self._hcg = hcg
self._prepare_for_model()

def _prepare_for_model(self):
pass

def _pre_forward(self, *inputs, **kwargs):
pass

def forward(self, *inputs, **kwargs):
self._pre_forward(*inputs, **kwargs)

output = self._layers(*inputs, **kwargs)

self._post_forward(output)

return output

def _post_forward(self, output):
pass
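MetaParallelBase is a template: subclasses plug behavior into _prepare_for_model (once, at construction) and _pre_forward / _post_forward (around every call into the wrapped layers). A toy subclass showing where those hooks fire (hypothetical, for illustration only; ModelParallel below is the real subclass added by this PR):

from paddle.distributed.fleet.meta_parallel.meta_parallel_base import MetaParallelBase

class LoggingParallel(MetaParallelBase):
    def _prepare_for_model(self):
        # parameter broadcast / sync would happen here, once
        print("preparing", self._layers.full_name())

    def _pre_forward(self, *inputs, **kwargs):
        print("running forward on", self._layers.full_name())

    def _post_forward(self, output):
        print("forward done")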
29 changes: 29 additions & 0 deletions python/paddle/distributed/fleet/meta_parallel/model_parallel.py
@@ -0,0 +1,29 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.fluid.dygraph.layers import Layer
from .meta_parallel_base import MetaParallelBase
from ..utils.hybrid_parallel_util import *


class ModelParallel(MetaParallelBase):
def __init__(self, layers, hcg, **kwargs):
super(ModelParallel, self).__init__(layers, hcg, **kwargs)

def _prepare_for_model(self):
broadcast_mp_parameters(self._layers, self._hcg)
broadcast_dp_parameters(self._layers, self._hcg)

def _pre_forward(self, *inputs, **kwargs):
return broadcast_input_data(self._hcg, *inputs, **kwargs)
96 changes: 96 additions & 0 deletions python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
@@ -0,0 +1,96 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import numpy as np
import warnings

from paddle import framework
import paddle
from paddle.fluid import core
from paddle.fluid.dygraph.parallel import _split_tensors, sync_params_buffers, construct_groups
from collections import OrderedDict


def _apply_collective_grads(parameters, comm_group):
grad_var_set = set()
grad_vars = []
sparse_grad_vars = []

for param in parameters:
if param.trainable and (param._grad_ivar() is not None):
g_var = param._grad_ivar()
assert not g_var._is_sparse(
), "Sparse parameters are not supported yet"
grad_vars.append(g_var)
assert g_var not in grad_var_set
grad_var_set.add(g_var)

coalesced_grads_and_vars = construct_groups(grad_vars, 128 * 1024 * 1024)

for index, (coalesced_grad, g_vars, g_shapes) in enumerate(coalesced_grads_and_vars):
# need to div nranks so every rank ends up with the averaged gradient
coalesced_grad = coalesced_grad / comm_group.nranks
paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
# write the reduced buffer back so _split_tensors copies it into the original grads
coalesced_grads_and_vars[index] = [coalesced_grad, g_vars, g_shapes]

_split_tensors(coalesced_grads_and_vars)


def broadcast_input_data(hcg, *inputs, **kwargs):
model_parallel_group = hcg.get_model_parallel_group()
src_rank = hcg.get_model_parallel_group_src_rank()

for input_ in inputs:
if isinstance(input_, core.VarBase):
with framework.no_grad():
paddle.distributed.broadcast(
input_,
src=src_rank,
group=model_parallel_group,
use_calc_stream=True)
else:
print("it doesn't support data type {}".format(type(input_)))

for k, v in kwargs.items():
if isinstance(v, core.VarBase):
with framework.no_grad():
paddle.distributed.broadcast(
v,
src=src_rank,
group=model_parallel_group,
use_calc_stream=True)
kwargs[k] = v
else:
print("it doesn't support data type {}".format(type(v)))
return inputs, kwargs


def broadcast_mp_parameters(model, hcg):
model_parallel_group = hcg.get_model_parallel_group()
src_rank = hcg.get_model_parallel_group_src_rank()
sync_params_buffers(
model, model_parallel_group, src_rank, is_model_parallel=True)


def broadcast_dp_parameters(model, hcg):
data_parallel_group = hcg.get_data_parallel_group()
src_rank = hcg.get_data_parallel_group_src_rank()
sync_params_buffers(
model, data_parallel_group, src_rank, is_model_parallel=False)


def fused_allreduce_gradients(parameter_list, hcg):
data_parallel_group = hcg.get_data_parallel_group()
with framework.no_grad():
_apply_collective_grads(parameter_list, data_parallel_group)
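In normal use, HybridParallelOptimizer.step() calls fused_allreduce_gradients for model-parallel runs that also have a data-parallel dimension: the dense gradients are coalesced, averaged over the data-parallel group, and split back. A hedged sketch of driving it by hand in a launched multi-rank job (fleet.fleet._hcg is the HybridCommunicateGroup built in fleet_base.py above and is read here only for illustration):

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

fleet.init(is_collective=True)

model = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())

loss = model(paddle.randn([8, 16])).mean()
loss.backward()
# average gradients across the data-parallel group before the optimizer update
fused_allreduce_gradients(list(model.parameters()), fleet.fleet._hcg)
opt.step()
opt.clear_grad()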