From 362baa3820725e599163039ecf2c9e0e0e0a2375 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 23 Aug 2022 08:50:44 +0000 Subject: [PATCH 01/31] add recompute_sequential --- python/paddle/distributed/fleet/model.py | 53 +++++-------------- .../distributed/fleet/utils/recompute.py | 53 +++++++++++++++++++ 2 files changed, 66 insertions(+), 40 deletions(-) diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index 988d2d928cc2b..27c4c7be72ee0 100644 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -20,46 +20,10 @@ from .meta_parallel import TensorParallel, model_parallel_random_seed from .meta_parallel import PipelineParallel, ShardingParallel from paddle.fluid import core -from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction +from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _initialize_recompute_setting, _initialize_recompute_hcg from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar from paddle.distributed import fleet - -class _RecomputeModelWrapper(paddle.nn.Layer): - - def __init__(self, model, segments=2, preserve_rng_state=True): - super(_RecomputeModelWrapper, self).__init__() - assert isinstance(model, paddle.nn.Sequential), ( - "The model passed to RecomputeModelWrapper must be of type " - "paddle.nn.Sequential.") - self._model = model - self._segments = segments - self._preserve_rng_state = preserve_rng_state - self._layers = list(model.children()) - self._segment_size = len(self._layers) // segments - - def _run_func(self, begin, end): - - def do_run(input): - for i in range(begin, end): - input = self._layers[i](input) - return input - - return do_run - - def _checkpoint(self, func, *args, **kwargs): - return LegacyRecomputeFunction.apply(func, self._preserve_rng_state, - *args) - - def forward(self, input): - end = 0 - for begin in range(0, self._segment_size * (self._segments - 1), - self._segment_size): - end = begin + self._segment_size - input = self._checkpoint(self._run_func(begin, end), input) - return self._run_func(end, len(self._layers))(input) - - _grad_scalar = None @@ -125,7 +89,6 @@ def forward(self, x): return model amp_enable = False - recompute_enable = False strategy = fleet_env._user_defined_strategy if strategy.amp == True: amp_enable = True @@ -155,8 +118,18 @@ def forward(self, x): use_dynamic_loss_scaling=use_dynamic_loss_scaling) if strategy.recompute == True: - recompute_enable = True - model = _RecomputeModelWrapper(model) + #NOTE when in hybrid parallel, init global recompute env. 
+ if fleet_env._hcg.get_parallel_mode() in [ + ParallelMode.TENSOR_PARALLEL, ParallelMode.PIPELINE_PARALLEL + ]: + _initialize_recompute_hcg(fleet_env._hcg) + + keys = strategy.recompute_configs.keys() + enable_offload = keys["enable_offload"] if "enable_offload" in keys( + ) else False + enable_partition = keys[ + "enable_partition"] if "enable_partition" in keys() else False + _initialize_recompute_setting(enable_offload, enable_partition) if strategy.heter_ccl_mode == True: distributed_model = paddle.DataParallel( diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py index f0c74159488a7..d6b4e6ba12f6b 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -474,3 +474,56 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): check_recompute_necessary(args) return RecomputeFunction.apply(function, preserve, *args) + + +def recompute_sequential(functions, segments, input, **kwargs): + """ + recompute intermediate activations to save then memory for 'Sequential' models. + + Parameters: + functions(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model + whose intermediate activations will be released to save memory in forward stage and will be recomputed + in backward stage for gradient calculation. + segments(int): Number of chunks to create in the model + input(Tensor): inputs to the function. + **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to + indicate whether to save the forward rng. If it is True, then the last forward rng value will be + restored when the forward recalculation of backpropagation is performed. The default + preserve_rng_state is True. + + Returns: + Output of function on args. + + Examples: + .. code-block:: python + + model = paddle.nn.Sequential(...) + input = recompute_sequential(model, segments, input) + """ + preserve = kwargs.pop('preserve_rng_state', True) + if kwargs: + raise ValueError("Unexpected keyword arguments: " + + ",".join(arg for arg in kwargs)) + + def _run_func(begin, end, functions): + + def do_run(input): + for i in range(begin, end): + input = functions[i](input) + return input + + return do_run + + assert isinstance(functions, paddle.nn.Sequential), ( + "The functions must be of type paddle.nn.Sequential.") + + functions = list(functions.children()) + segment_size = len(functions) // segments + + end = 0 + for begin in range(0, segment_size * (segments - 1), segment_size): + end = begin + segment_size + input = recompute(_run_func(begin, end), + input, + preserve_rng_state=preserve) + return _run_func(end, len(functions))(input) From 49d246306f9b6decd30f410b68089d35d4555f0d Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 23 Aug 2022 10:16:41 +0000 Subject: [PATCH 02/31] add recompute dirs. 
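
Move recompute.py out of fleet/utils into a dedicated fleet/recompute
package, keeping the old public import path working through a re-export in
fleet/utils/__init__.py. A minimal sketch of the compatibility expectation
(sketch only, not part of the diff; it assumes the fleet.recompute package
__init__ that a later patch in this series adds):

    # old path, still valid via the re-export kept in fleet/utils/__init__.py
    from paddle.distributed.fleet.utils import recompute

    # new physical location of the implementation
    from paddle.distributed.fleet.recompute.recompute import recompute as _impl

    assert recompute is _impl  # both names bind the same function object
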
--- .../paddle/distributed/fleet/{utils => recompute}/recompute.py | 0 python/paddle/distributed/fleet/utils/__init__.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename python/paddle/distributed/fleet/{utils => recompute}/recompute.py (100%) diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py similarity index 100% rename from python/paddle/distributed/fleet/utils/recompute.py rename to python/paddle/distributed/fleet/recompute/recompute.py diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 1bf90a22e375c..582bf3f70f0e9 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -15,7 +15,7 @@ from .fs import LocalFS # noqa: F401 from .fs import HDFSClient # noqa: F401 from .ps_util import DistributedInfer # noqa: F401 -from .recompute import recompute # noqa: F401 +from ..recompute.recompute import recompute # noqa: F401 from . import log_util # noqa: F401 from . import hybrid_parallel_util # noqa: F401 From d2b25197523c40aea980149b1cfd785d20c7e6fd Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 23 Aug 2022 10:32:11 +0000 Subject: [PATCH 03/31] update cite. --- python/paddle/distributed/fleet/__init__.py | 2 ++ .../distributed/fleet/recompute/recompute.py | 16 ++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 11d7643c676dd..bf92602dade5c 100755 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -32,6 +32,7 @@ from .model import distributed_model from .optimizer import distributed_optimizer from .scaler import distributed_scaler +from .recompute.recompute import recompute __all__ = [ #noqa "CommunicateTopology", "UtilBase", "HybridCommunicateGroup", @@ -90,3 +91,4 @@ shrink = fleet.shrink get_hybrid_communicate_group = fleet.get_hybrid_communicate_group distributed_scaler = distributed_scaler +recompute = recompute diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index d6b4e6ba12f6b..0656c4cea9b66 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -20,6 +20,8 @@ from paddle.fluid import framework import contextlib from paddle.fluid.framework import in_dygraph_mode +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute import logging @@ -30,7 +32,7 @@ ch.setFormatter(formatter) logger.addHandler(ch) -__all__ = [] +__all__ = ["recompute", "recompute_sequential"] def detach_variable(inputs): @@ -473,7 +475,17 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): if framework._dygraph_tracer()._has_grad: check_recompute_necessary(args) - return RecomputeFunction.apply(function, preserve, *args) + fleet_env = fleet.fleet + #NOTE: when in hybrid parallel, recompute supports offload and partition function, config it in DistributedStrategy firstly. + if hasattr(fleet_env, "_hcg") and fleet_env._hcg.get_parallel_mode() in [ + ParallelMode.TENSOR_PARALLEL, ParallelMode.PIPELINE_PARALLEL + ]: + # global env var in _hp_recompute + global _hcg + assert _hcg is not None, "please init recompute env in hybrid parallel by add recompute_config in DistributedStrategy firstly." 
+ return _hp_recompute(function, *args) + else: + return RecomputeFunction.apply(function, preserve, *args) def recompute_sequential(functions, segments, input, **kwargs): From b15ce9f07f46994ebf4bb6df19427814d89fa702 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 23 Aug 2022 10:35:39 +0000 Subject: [PATCH 04/31] update. --- python/paddle/distributed/fleet/recompute/recompute.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index 0656c4cea9b66..dbcff13123164 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -485,6 +485,7 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): assert _hcg is not None, "please init recompute env in hybrid parallel by add recompute_config in DistributedStrategy firstly." return _hp_recompute(function, *args) else: + # when in pure data parallel or non-parallel training, use simple recompute. return RecomputeFunction.apply(function, preserve, *args) From 18462ac8059fc7998e3b78d872bed8cb63cd3fe0 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 23 Aug 2022 10:46:32 +0000 Subject: [PATCH 05/31] recompute unify. --- python/paddle/distributed/fleet/__init__.py | 3 ++- .../paddle/distributed/fleet/recompute/recompute.py | 13 ++++++------- python/paddle/distributed/fleet/utils/__init__.py | 5 +++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index bf92602dade5c..c1762024a4e34 100755 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -32,7 +32,7 @@ from .model import distributed_model from .optimizer import distributed_optimizer from .scaler import distributed_scaler -from .recompute.recompute import recompute +from .recompute.recompute import recompute, recompute_sequential __all__ = [ #noqa "CommunicateTopology", "UtilBase", "HybridCommunicateGroup", @@ -92,3 +92,4 @@ get_hybrid_communicate_group = fleet.get_hybrid_communicate_group distributed_scaler = distributed_scaler recompute = recompute +recompute_sequential = recompute_sequential diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index dbcff13123164..6fe9bb956fd58 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -476,16 +476,15 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): check_recompute_necessary(args) fleet_env = fleet.fleet + strategy = fleet_env._user_defined_strategy #NOTE: when in hybrid parallel, recompute supports offload and partition function, config it in DistributedStrategy firstly. - if hasattr(fleet_env, "_hcg") and fleet_env._hcg.get_parallel_mode() in [ - ParallelMode.TENSOR_PARALLEL, ParallelMode.PIPELINE_PARALLEL - ]: - # global env var in _hp_recompute - global _hcg - assert _hcg is not None, "please init recompute env in hybrid parallel by add recompute_config in DistributedStrategy firstly." + if hasattr( + fleet_env, + "_hcg") and strategy.recompute and fleet_env._hcg.get_parallel_mode( + ) in [ParallelMode.TENSOR_PARALLEL, ParallelMode.PIPELINE_PARALLEL]: return _hp_recompute(function, *args) else: - # when in pure data parallel or non-parallel training, use simple recompute. 
+ # when in pure data parallel or non-parallel training or strategy.recompute is False, use simple recompute. return RecomputeFunction.apply(function, preserve, *args) diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 582bf3f70f0e9..4a6257af09ebe 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -15,11 +15,12 @@ from .fs import LocalFS # noqa: F401 from .fs import HDFSClient # noqa: F401 from .ps_util import DistributedInfer # noqa: F401 -from ..recompute.recompute import recompute # noqa: F401 +from ..recompute.recompute import recompute, recompute_sequential # noqa: F401 from . import log_util # noqa: F401 from . import hybrid_parallel_util # noqa: F401 __all__ = [ #noqa - "LocalFS", "recompute", "DistributedInfer", "HDFSClient" + "LocalFS", "recompute", "recompute_sequential", "DistributedInfer", + "HDFSClient" ] From 91bb00f7aa309e2be855989ffff764e685f05ed9 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 23 Aug 2022 12:37:28 +0000 Subject: [PATCH 06/31] update recompute. --- python/paddle/distributed/fleet/__init__.py | 2 +- .../fleet/meta_parallel/pp_utils/utils.py | 2 +- .../distributed/fleet/recompute/__init__.py | 20 +++++++++++++++++++ .../distributed/fleet/recompute/recompute.py | 5 ++++- .../distributed/fleet/utils/__init__.py | 2 +- 5 files changed, 27 insertions(+), 4 deletions(-) create mode 100644 python/paddle/distributed/fleet/recompute/__init__.py diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index c1762024a4e34..bd37a72087700 100755 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -32,7 +32,7 @@ from .model import distributed_model from .optimizer import distributed_optimizer from .scaler import distributed_scaler -from .recompute.recompute import recompute, recompute_sequential +from .recompute import recompute, recompute_sequential __all__ = [ #noqa "CommunicateTopology", "UtilBase", "HybridCommunicateGroup", diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index 46fe7e641733a..e30ba0ff7af05 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -19,7 +19,7 @@ from paddle import _C_ops from paddle.autograd import PyLayer from paddle.fluid import framework -from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker +from ...recompute.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker from ..parallel_layers.random import get_rng_state_tracker from paddle.fluid.framework import in_dygraph_mode diff --git a/python/paddle/distributed/fleet/recompute/__init__.py b/python/paddle/distributed/fleet/recompute/__init__.py new file mode 100644 index 0000000000000..e9aeaa66f3ca8 --- /dev/null +++ b/python/paddle/distributed/fleet/recompute/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .recompute import recompute, recompute_sequential, swith_rng_state_tracker, check_recompute_necessary, detach_variable + +__all__ = [ + "recompute", "recompute_sequential", "swith_rng_state_tracker", + "check_recompute_necessary", "detach_variable" +] diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index 6fe9bb956fd58..8d3959751235f 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -32,7 +32,10 @@ ch.setFormatter(formatter) logger.addHandler(ch) -__all__ = ["recompute", "recompute_sequential"] +__all__ = [ + "recompute", "recompute_sequential", "swith_rng_state_tracker", + "check_recompute_necessary", "detach_variable" +] def detach_variable(inputs): diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 4a6257af09ebe..34b911b90b596 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -15,7 +15,7 @@ from .fs import LocalFS # noqa: F401 from .fs import HDFSClient # noqa: F401 from .ps_util import DistributedInfer # noqa: F401 -from ..recompute.recompute import recompute, recompute_sequential # noqa: F401 +from ..recompute import recompute, recompute_sequential # noqa: F401 from . import log_util # noqa: F401 from . import hybrid_parallel_util # noqa: F401 From a249601abd450e8d23a6de4dcbc3c26207234af3 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Wed, 24 Aug 2022 01:35:42 +0000 Subject: [PATCH 07/31] update recompute. 
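
Stop importing the recompute helpers into pp_utils/utils.py at module load
time and resolve them through their fully qualified path at call time
instead. A small sketch of the deferred-lookup pattern applied below (the
import-cycle motivation is an assumption; only the call sites come from
this patch):

    import paddle

    def _detach_for_backward(inputs):
        # looked up lazily, once paddle.distributed.fleet.recompute is
        # importable, instead of via a top-level `from ... import`
        return paddle.distributed.fleet.recompute.detach_variable(tuple(inputs))
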
--- .../distributed/fleet/meta_parallel/pp_utils/utils.py | 10 +++++----- python/paddle/distributed/fleet/recompute/recompute.py | 7 +++---- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index e30ba0ff7af05..9c7d845354b18 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -19,7 +19,6 @@ from paddle import _C_ops from paddle.autograd import PyLayer from paddle.fluid import framework -from ...recompute.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker from ..parallel_layers.random import get_rng_state_tracker from paddle.fluid.framework import in_dygraph_mode @@ -162,7 +161,7 @@ class _HPRecomputeFunction(PyLayer): @staticmethod def forward(ctx, run_function, all_outputs, *args): - check_recompute_necessary(args) + paddle.distributed.fleet.recompute.check_recompute_necessary(args) # store for recomputing ctx.run_function = run_function @@ -247,13 +246,14 @@ def backward(ctx, *args): tracer._has_grad = True # need restore auto_cast state as well as w/b list - with swith_rng_state_tracker(ctx.fwd_cuda_rng_state, - ctx.fwd_cuda_rng_state_tracker): + with paddle.distributed.fleet.recompute.swith_rng_state_tracker( + ctx.fwd_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker): with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, custom_black_list=ctx.amp_black_list, level=ctx.amp_level): - detached_inputs = detach_variable(tuple(inputs)) + detached_inputs = paddle.distributed.fleet.recompute.detach_variable( + tuple(inputs)) outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, (core.VarBase, core.eager.Tensor)): diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index 8d3959751235f..b1450620eeb84 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -479,12 +479,11 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): check_recompute_necessary(args) fleet_env = fleet.fleet - strategy = fleet_env._user_defined_strategy #NOTE: when in hybrid parallel, recompute supports offload and partition function, config it in DistributedStrategy firstly. if hasattr( - fleet_env, - "_hcg") and strategy.recompute and fleet_env._hcg.get_parallel_mode( - ) in [ParallelMode.TENSOR_PARALLEL, ParallelMode.PIPELINE_PARALLEL]: + fleet_env, "_hcg" + ) and fleet_env._user_defined_strategy.recompute and fleet_env._hcg.get_parallel_mode( + ) in [ParallelMode.TENSOR_PARALLEL, ParallelMode.PIPELINE_PARALLEL]: return _hp_recompute(function, *args) else: # when in pure data parallel or non-parallel training or strategy.recompute is False, use simple recompute. From 123918048478533df6643fd91394cb0cf4f5bbde Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Wed, 24 Aug 2022 12:59:49 +0000 Subject: [PATCH 08/31] refact recompute. 
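
Fold the pipeline-parallel recompute implementation (_hp_recompute and its
module-level _hcg / offload / partition globals) into the recompute package
as an explicit entry point: fleet.hybrid_recompute(function, offload,
partition, *args). A sketch of how a caller now threads its own flags
through (it mirrors the pp_layers.py change below and assumes an already
initialized hybrid-parallel fleet setup):

    from paddle.distributed import fleet

    def run_chunk(forward_chunk, recompute_offload, recompute_partition, *inputs):
        # offload / partition are per-call arguments now, not process globals
        return fleet.hybrid_recompute(forward_chunk, recompute_offload,
                                      recompute_partition, *inputs)
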
--- python/paddle/distributed/fleet/__init__.py | 3 +- .../parallel_layers/pp_layers.py | 9 ++- .../fleet/meta_parallel/pipeline_parallel.py | 3 - .../fleet/meta_parallel/pp_utils/__init__.py | 2 - .../pp_utils/p2p_communication.py | 4 +- .../meta_parallel/sharding/sharding_stage3.py | 2 +- python/paddle/distributed/fleet/model.py | 15 ---- .../distributed/fleet/recompute/__init__.py | 5 +- .../hybrid_recompute.py} | 80 ++++++++++--------- .../distributed/fleet/recompute/recompute.py | 13 +-- .../distributed/models/moe/moe_layer.py | 11 ++- 11 files changed, 64 insertions(+), 83 deletions(-) rename python/paddle/distributed/fleet/{meta_parallel/pp_utils/utils.py => recompute/hybrid_recompute.py} (81%) diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index bd37a72087700..794198598f459 100755 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -32,7 +32,7 @@ from .model import distributed_model from .optimizer import distributed_optimizer from .scaler import distributed_scaler -from .recompute import recompute, recompute_sequential +from .recompute import recompute, recompute_sequential, hybrid_recompute __all__ = [ #noqa "CommunicateTopology", "UtilBase", "HybridCommunicateGroup", @@ -93,3 +93,4 @@ distributed_scaler = distributed_scaler recompute = recompute recompute_sequential = recompute_sequential +hybrid_recompute = hybrid_recompute diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 58b0515e0bac8..67e6cf22a977e 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -49,7 +49,7 @@ import paddle from paddle.fluid.dygraph.layers import Layer from ...utils.log_util import logger, layer_to_str -from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting +from paddle.distributed import fleet from paddle.fluid.framework import in_dygraph_mode __all__ = [] @@ -190,7 +190,6 @@ def __init__(self, logger.info( "Start Recompute for PipeLineParallel. 
recompute_offload: {}, recompute_partition: {}" .format(recompute_offload, recompute_partition)) - _initialize_recompute_setting(recompute_offload, recompute_partition) world_size = dist.get_world_size() self.global_rank = dist.get_rank() @@ -402,8 +401,10 @@ def forward(self, input): input = (input, ) if self._need_recompute(funcs, input): - input = _hp_recompute( - self.forward_function(start_idx, end_idx), *input) + input = fleet.hybrid_recompute( + self.forward_function(start_idx, + end_idx), self._recompute_offload, + self._recompute_partition, *input) else: input = self.forward_function(start_idx, end_idx)(*input) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 3135c5379e880..6d8826febb731 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -14,7 +14,6 @@ import paddle import paddle.fluid as fluid from .meta_parallel_base import MetaParallelBase -from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg from .parallel_layers.pp_layers import PipelineLayer from ..utils.hybrid_parallel_util import broadcast_mp_parameters @@ -55,8 +54,6 @@ def __init__(self, layers, hcg, strategy): p2p.initialize_p2p_groups(hcg, self._using_cache) - _initialize_recompute_hcg(hcg) - self.is_first_stage = self.stage_id == 0 self.is_last_stage = (self.stage_id == (self.num_stages - 1)) self.global_rank = self._hcg.get_global_rank() diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py index 786eb20487a52..04575bfb23194 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py @@ -12,6 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .utils import get_tensor_bytes - __all__ = [] diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index f42752c5e8f1b..452ff18fd0a48 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -13,7 +13,6 @@ # limitations under the License. 
import paddle -from .utils import paddle_2_number, number_2_dtype from ...utils.log_util import logger import numpy as np from paddle import _C_ops @@ -103,6 +102,7 @@ def recv_meta(self, group): self.recv_stop_gradient = tuple(stop_grads) def _send_dims_shape_dtype(self, tensor, group): + from ...recompute.hybrid_recompute import paddle_2_number # send len(shape) dims = paddle.to_tensor(len(tensor.shape)) dst_rank = group.ranks[1] @@ -143,6 +143,7 @@ def send_meta(self, tensor, group): self._send_dims_shape_dtype(d, group=group) def set_send_message(self, tensor): + from ...recompute.hybrid_recompute import paddle_2_number if isinstance(tensor, (paddle.Tensor, core.eager.Tensor)): self.send_shape_message = tensor.shape self.send_dtype_message = paddle_2_number(tensor.dtype) @@ -274,6 +275,7 @@ def allgather_partial(tensor, def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): + from ...recompute.hybrid_recompute import number_2_dtype global _hcg tensor_recv_prev = None diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py index 67d48c8abba1b..61b9e790474cd 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py @@ -33,7 +33,7 @@ from paddle.distributed.collective import _get_global_group from .sharding_utils import Type, ShardingClipGrad, device_guard -from ..pp_utils.utils import _all_gather +from ...recompute.hybrid_recompute import _all_gather from ...utils.internal_storage import GradStorage # CUDA alignment 256 bytes diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index 27c4c7be72ee0..e35f577af0b67 100644 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -20,7 +20,6 @@ from .meta_parallel import TensorParallel, model_parallel_random_seed from .meta_parallel import PipelineParallel, ShardingParallel from paddle.fluid import core -from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _initialize_recompute_setting, _initialize_recompute_hcg from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar from paddle.distributed import fleet @@ -117,20 +116,6 @@ def forward(self, x): decr_every_n_nan_or_inf=decr_every_n_nan_or_inf, use_dynamic_loss_scaling=use_dynamic_loss_scaling) - if strategy.recompute == True: - #NOTE when in hybrid parallel, init global recompute env. - if fleet_env._hcg.get_parallel_mode() in [ - ParallelMode.TENSOR_PARALLEL, ParallelMode.PIPELINE_PARALLEL - ]: - _initialize_recompute_hcg(fleet_env._hcg) - - keys = strategy.recompute_configs.keys() - enable_offload = keys["enable_offload"] if "enable_offload" in keys( - ) else False - enable_partition = keys[ - "enable_partition"] if "enable_partition" in keys() else False - _initialize_recompute_setting(enable_offload, enable_partition) - if strategy.heter_ccl_mode == True: distributed_model = paddle.DataParallel( model, diff --git a/python/paddle/distributed/fleet/recompute/__init__.py b/python/paddle/distributed/fleet/recompute/__init__.py index e9aeaa66f3ca8..f9597f7a11f79 100644 --- a/python/paddle/distributed/fleet/recompute/__init__.py +++ b/python/paddle/distributed/fleet/recompute/__init__.py @@ -13,8 +13,9 @@ # limitations under the License. 
from .recompute import recompute, recompute_sequential, swith_rng_state_tracker, check_recompute_necessary, detach_variable +from .hybrid_recompute import hybrid_recompute __all__ = [ - "recompute", "recompute_sequential", "swith_rng_state_tracker", - "check_recompute_necessary", "detach_variable" + "recompute", "recompute_sequential", "hybrid_recompute", + "swith_rng_state_tracker", "check_recompute_necessary", "detach_variable" ] diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/recompute/hybrid_recompute.py similarity index 81% rename from python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py rename to python/paddle/distributed/fleet/recompute/hybrid_recompute.py index 9c7d845354b18..3ab796ba817ff 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/recompute/hybrid_recompute.py @@ -19,10 +19,15 @@ from paddle import _C_ops from paddle.autograd import PyLayer from paddle.fluid import framework -from ..parallel_layers.random import get_rng_state_tracker +from ..meta_parallel.parallel_layers.random import get_rng_state_tracker from paddle.fluid.framework import in_dygraph_mode +from paddle.distributed import fleet +from .recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker -__all__ = [] +__all__ = [ + "hybrid_recompute", "is_float_tensor", "get_tensor_dtype", + "paddle_2_number", "get_tensor_bytes" +] FLOAT_TYPE_DICT = { paddle.float16: "float16", @@ -87,23 +92,6 @@ def get_tensor_bytes(tensor): return tensor.numel() * elem_size -_hcg = None -_recompute_offload = False -_recompute_partition = False - - -def _initialize_recompute_setting(is_offload, is_partition): - global _recompute_offload, _recompute_partition - - _recompute_offload = is_offload - _recompute_partition = is_partition - - -def _initialize_recompute_hcg(hcg): - global _hcg - _hcg = hcg - - def _all_gather(tensor, group=None, use_calc_stream=True): """ The main difference with paddle.distributed.all_gather: @@ -119,10 +107,9 @@ def _all_gather(tensor, group=None, use_calc_stream=True): def _split_activation(tensor): - global _hcg - mp_degree = _hcg.get_model_parallel_world_size() - mp_rank = _hcg.get_model_parallel_rank() + mp_degree = fleet.fleet._hcg.get_model_parallel_world_size() + mp_rank = fleet.fleet._hcg.get_model_parallel_rank() if mp_degree < 2: return tensor @@ -141,10 +128,9 @@ def _split_activation(tensor): def _merge_activation(tensor): - global _hcg - mp_degree = _hcg.get_model_parallel_world_size() - mp_rank = _hcg.get_model_parallel_rank() - mp_group = _hcg.get_model_parallel_group() + mp_degree = fleet.fleet._hcg.get_model_parallel_world_size() + mp_rank = fleet.fleet._hcg.get_model_parallel_rank() + mp_group = fleet.fleet._hcg.get_model_parallel_group() if mp_degree < 2: return tensor return _all_gather(tensor, group=mp_group) @@ -160,8 +146,8 @@ class _HPRecomputeFunction(PyLayer): """ @staticmethod - def forward(ctx, run_function, all_outputs, *args): - paddle.distributed.fleet.recompute.check_recompute_necessary(args) + def forward(ctx, run_function, all_outputs, offload, partition, *args): + check_recompute_necessary(args) # store for recomputing ctx.run_function = run_function @@ -171,6 +157,10 @@ def forward(ctx, run_function, all_outputs, *args): ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( ).get_states_tracker() + # save config info + ctx.offload = offload + ctx.partition = partition + # save input for backward ctx.inputs = 
[] ctx.tensor_indices = [] @@ -200,13 +190,13 @@ def forward(ctx, run_function, all_outputs, *args): for i, arg in enumerate(args): if paddle.is_tensor(arg): state = arg.stop_gradient - if _recompute_partition: + if partition: ctx.tensor_shapes.append(arg.shape) partition = _split_activation(arg.detach()).clone() # TODO(shenliang03) not use calculate stream to D2H to speed - arg = partition.cpu() if _recompute_offload else partition + arg = partition.cpu() if offload else partition else: - arg = arg.cpu() if _recompute_offload else arg + arg = arg.cpu() if offload else arg arg.stop_gradient = state tensor_inputs.append(arg) ctx.tensor_indices.append(i) @@ -234,26 +224,25 @@ def backward(ctx, *args): device_id = paddle.distributed.ParallelEnv().device_id for i, idx in enumerate(tensor_indices): - if _recompute_partition: + if ctx.partition: state = tensors[i].stop_gradient tensors[i] = _merge_activation( tensors[i]).detach().reshape_(tensor_shapes[i]) tensors[i].stop_gradient = state inputs[idx] = tensors[i].cuda( - device_id) if _recompute_offload else tensors[i] + device_id) if ctx.offload else tensors[i] tracer = framework._dygraph_tracer() tracer._has_grad = True # need restore auto_cast state as well as w/b list - with paddle.distributed.fleet.recompute.swith_rng_state_tracker( - ctx.fwd_cuda_rng_state, ctx.fwd_cuda_rng_state_tracker): + with swith_rng_state_tracker(ctx.fwd_cuda_rng_state, + ctx.fwd_cuda_rng_state_tracker): with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, custom_black_list=ctx.amp_black_list, level=ctx.amp_level): - detached_inputs = paddle.distributed.fleet.recompute.detach_variable( - tuple(inputs)) + detached_inputs = detach_variable(tuple(inputs)) outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, (core.VarBase, core.eager.Tensor)): @@ -283,15 +272,28 @@ def backward(ctx, *args): return grads -def _hp_recompute(function, *args): +def hybrid_recompute(function, offload, partition, *args): + """ # NODTE(shenliang03)The current hybrid parallel recompute has limitations. # It cannot handle the following situations: # 1. The calculation output of recompute, there are tensors that do not require gradients. # 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach(). # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor + Parameters: + function(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model + whose intermediate activations will be released to save memory in forward stage and will be recomputed + in backward stage for gradient calculation. + offload(bool): whether to offload checkpoint. + partition: whether to split activation into each rank in model group. + *args(Tensor): inputs to the function. + Returns: + Output of function on args. 
+ + """ + all_outputs = [] - _HPRecomputeFunction.apply(function, all_outputs, *args) + _HPRecomputeFunction.apply(function, all_outputs, offload, partition, *args) if len(all_outputs) == 1: return all_outputs[0] diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index b1450620eeb84..c0ff3179ab74b 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -20,8 +20,6 @@ from paddle.fluid import framework import contextlib from paddle.fluid.framework import in_dygraph_mode -from paddle.distributed import fleet -from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute import logging @@ -478,16 +476,7 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): if framework._dygraph_tracer()._has_grad: check_recompute_necessary(args) - fleet_env = fleet.fleet - #NOTE: when in hybrid parallel, recompute supports offload and partition function, config it in DistributedStrategy firstly. - if hasattr( - fleet_env, "_hcg" - ) and fleet_env._user_defined_strategy.recompute and fleet_env._hcg.get_parallel_mode( - ) in [ParallelMode.TENSOR_PARALLEL, ParallelMode.PIPELINE_PARALLEL]: - return _hp_recompute(function, *args) - else: - # when in pure data parallel or non-parallel training or strategy.recompute is False, use simple recompute. - return RecomputeFunction.apply(function, preserve, *args) + return RecomputeFunction.apply(function, preserve, *args) def recompute_sequential(functions, segments, input, **kwargs): diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index 28740917c13f8..c6a564e2358df 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -34,7 +34,7 @@ from paddle.autograd import PyLayer from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate from .utils import count_by_gate -from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute +from paddle.distributed import fleet from paddle import fluid from paddle.fluid.framework import in_dygraph_mode @@ -314,6 +314,8 @@ def __init__(self, super(MoELayer, self).__init__() recompute_interval = kwargs.get("recompute_interval", 0) + recompute_offload = kwargs.get("recompute_offload", False) + recompute_partition = kwargs.get("recompute_partition", False) if gate is None: gate = dict() @@ -328,6 +330,8 @@ def __init__(self, self.world_size = self.group.nranks self.num_expert = len(experts) self.recompute_interval = recompute_interval + self.recompute_offload = recompute_offload + self.recompute_partition = recompute_partition assert experts is not None self.experts = experts @@ -422,8 +426,9 @@ def experts_fwd(x, fwd_expert_count, experts): if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: - x = _hp_recompute(experts_fwd, x, fwd_expert_count.numpy(), - self.experts) + x = fleet.hybrid_recompute(experts_fwd, self.recompute_offload, + self.recompute_partition, x, + fwd_expert_count.numpy(), self.experts) out_batch_size = inp.shape[0] if len(gate.shape) == 2: From da0b2a9460fdc0177b5eec6d1a219cda73bc1c58 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Wed, 24 Aug 2022 15:03:32 +0000 Subject: [PATCH 09/31] update. 
--- python/paddle/distributed/fleet/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 34b911b90b596..4a6257af09ebe 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -15,7 +15,7 @@ from .fs import LocalFS # noqa: F401 from .fs import HDFSClient # noqa: F401 from .ps_util import DistributedInfer # noqa: F401 -from ..recompute import recompute, recompute_sequential # noqa: F401 +from ..recompute.recompute import recompute, recompute_sequential # noqa: F401 from . import log_util # noqa: F401 from . import hybrid_parallel_util # noqa: F401 From df8cc54152f4afa8395984117520949e415381ef Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Thu, 25 Aug 2022 01:41:31 +0000 Subject: [PATCH 10/31] update. --- python/paddle/distributed/fleet/utils/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 4a6257af09ebe..0020c84acdd79 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -15,6 +15,7 @@ from .fs import LocalFS # noqa: F401 from .fs import HDFSClient # noqa: F401 from .ps_util import DistributedInfer # noqa: F401 +from paddle.distributed import fleet from ..recompute.recompute import recompute, recompute_sequential # noqa: F401 from . import log_util # noqa: F401 From d73393fc64dd1e0b4777b1ee7341e61f425eba9c Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Thu, 25 Aug 2022 02:29:48 +0000 Subject: [PATCH 11/31] update test. --- .../unittests/dygraph_hybrid_recompute.py | 10 ++- .../tests/unittests/test_dygraph_recompute.py | 62 +++++++++++++++++-- .../test_dygraph_recompute_for_eager.py | 55 ++++++++++++++-- 3 files changed, 114 insertions(+), 13 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_hybrid_recompute.py b/python/paddle/fluid/tests/unittests/dygraph_hybrid_recompute.py index 20196a98eb144..16e482139ff82 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_hybrid_recompute.py +++ b/python/paddle/fluid/tests/unittests/dygraph_hybrid_recompute.py @@ -70,9 +70,10 @@ def forward(self, pred, label): class RecomputeMatmulBlock(nn.Layer): - def __init__(self, mp, seed, m, n, k): + def __init__(self, mp, seed, m, n, k, use_fleet=False): super(RecomputeMatmulBlock, self).__init__() self.mp = mp + self.use_fleet = use_fleet if mp is not None and mp.nranks > 1: mp_linear_1 = fleet.meta_parallel.ColumnParallelLinear( m, @@ -99,7 +100,10 @@ def __init__(self, mp, seed, m, n, k): def forward(self, x): if self.mp: - return recompute(self.layers, x) + if self.use_fleet: + return fleet.recompute(self.layers, x) + else: + return recompute(self.layers, x) else: return self.layers(x) @@ -139,7 +143,7 @@ def __init__(self, hcg): self.layers_pp.append(dp_linear) mp = hcg.get_model_parallel_group() if hcg else None for i in range(6): - mp_layer = RecomputeBlock(mp, 1024 + i, 64, 128, 64) + mp_layer = RecomputeBlock(mp, 1024 + i, 64, 128, 64, True) act = nn.ReLU6() layer_seq = nn.Sequential(mp_layer, act) self.layers_pp.append(layer_seq) diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py index 11ca15fd33104..d42c3294805f2 100755 --- a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py +++ 
b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py @@ -21,6 +21,7 @@ from paddle.autograd import PyLayer from paddle.distributed.fleet.utils import recompute import random +from paddle.distributed import fleet import paddle.fluid.layers as layers @@ -53,40 +54,62 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], + use_fleet=False, + use_fleet_sq=False, + segments=1, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs + self.use_fleet = use_fleet + self.use_fleet_sq = use_fleet_sq + self.segments = segments + self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) + if self.use_fleet_sq: + self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, + self.runfunc2, self.runfunc3, + self.runfunc4) + def forward(self, inputs): + if self.use_fleet_sq: + return fleet.recompute_sequential(self.runfuncs, self.segments, + inputs) + if 0 in self.recompute_blocks: - inputs = recompute(self.runfunc0, inputs) + recompute_func = fleet.recompute if self.use_fleet else recompute + inputs = recompute_func(self.runfunc0, inputs) else: inputs = self.runfunc0(inputs) if 1 in self.recompute_blocks: - inputs = recompute(self.runfunc1, inputs) + recompute_func = fleet.recompute if self.use_fleet else recompute + inputs = recompute_func(self.runfunc1, inputs) else: inputs = self.runfunc1(inputs) if 2 in self.recompute_blocks: - inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) + recompute_func = fleet.recompute if self.use_fleet else recompute + inputs = recompute_func(self.runfunc2, inputs, + **self.recompute_kwargs) else: inputs = self.runfunc2(inputs) if 3 in self.recompute_blocks: - inputs = recompute(self.runfunc3, inputs) + recompute_func = fleet.recompute if self.use_fleet else recompute + inputs = recompute_func(self.runfunc3, inputs) else: inputs = self.runfunc3(inputs) if 4 in self.recompute_blocks: - inputs = recompute(self.runfunc4, inputs) + recompute_func = fleet.recompute if self.use_fleet else recompute + inputs = recompute_func(self.runfunc4, inputs) else: inputs = self.runfunc4(inputs) @@ -179,6 +202,35 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # recompute second block using fleet + loss, param, grad = run_model(recompute_block=[1], + use_fleet=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute second & fourth block using fleet + loss, param, grad = run_model(recompute_block=[1, 3], + use_fleet=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute using fleet.recompute_sequential, segments=1 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute using fleet.recompute_sequential, segments=2 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + 
pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_recompute_for_eager.py b/python/paddle/fluid/tests/unittests/test_dygraph_recompute_for_eager.py index bc97d53485be9..d986a430c658c 100755 --- a/python/paddle/fluid/tests/unittests/test_dygraph_recompute_for_eager.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_recompute_for_eager.py @@ -24,6 +24,7 @@ import paddle from paddle.autograd import PyLayer from paddle.distributed.fleet.utils import recompute +from paddle.distributed import fleet import random import paddle.fluid.layers as layers @@ -57,40 +58,62 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], + use_fleet=False, + use_fleet_sq=False, + segments=1, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs + self.use_fleet = use_fleet + self.use_fleet_sq = use_fleet_sq + self.segments = segments + self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) + if self.use_fleet_sq: + self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, + self.runfunc2, self.runfunc3, + self.runfunc4) + def forward(self, inputs): + if self.use_fleet_sq: + return fleet.recompute_sequential(self.runfuncs, self.segments, + inputs) + if 0 in self.recompute_blocks: - inputs = recompute(self.runfunc0, inputs) + recompute_func = fleet.recompute if self.use_fleet else recompute + inputs = recompute_func(self.runfunc0, inputs) else: inputs = self.runfunc0(inputs) if 1 in self.recompute_blocks: - inputs = recompute(self.runfunc1, inputs) + recompute_func = fleet.recompute if self.use_fleet else recompute + inputs = recompute_func(self.runfunc1, inputs) else: inputs = self.runfunc1(inputs) if 2 in self.recompute_blocks: - inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) + recompute_func = fleet.recompute if self.use_fleet else recompute + inputs = recompute_func(self.runfunc2, inputs, + **self.recompute_kwargs) else: inputs = self.runfunc2(inputs) if 3 in self.recompute_blocks: - inputs = recompute(self.runfunc3, inputs) + recompute_func = fleet.recompute if self.use_fleet else recompute + inputs = recompute_func(self.runfunc3, inputs) else: inputs = self.runfunc3(inputs) if 4 in self.recompute_blocks: - inputs = recompute(self.runfunc4, inputs) + recompute_func = fleet.recompute if self.use_fleet else recompute + inputs = recompute_func(self.runfunc4, inputs) else: inputs = self.runfunc4(inputs) @@ -183,6 +206,28 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # recompute second block using fleet + loss, param, grad = run_model(recompute_block=[1], + use_fleet=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute_sequential with segments=1 using fleet + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) 
+ + # recompute_sequential with segments=2 using fleet + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() From 47d2529829126b3fa3a7abbb04ae7f73409277f3 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Thu, 25 Aug 2022 03:09:47 +0000 Subject: [PATCH 12/31] update. --- .../distributed/fleet/utils/__init__.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 0020c84acdd79..5f6f903f37477 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -16,7 +16,7 @@ from .fs import HDFSClient # noqa: F401 from .ps_util import DistributedInfer # noqa: F401 from paddle.distributed import fleet -from ..recompute.recompute import recompute, recompute_sequential # noqa: F401 +import paddle.utils.deprecated as deprecated from . import log_util # noqa: F401 from . import hybrid_parallel_util # noqa: F401 @@ -25,3 +25,21 @@ "LocalFS", "recompute", "recompute_sequential", "DistributedInfer", "HDFSClient" ] + + +@deprecated( + since="2.4.0", + update_to="paddle.distributed.fleet.recompute_sequential", + level=1, + reason="Please use new recompute_sequential API(fleet.recompute_sequential) " +) +def recompute_sequential(functions, segments, input, **kwargs): + return fleet.recompute_sequential(functions, segments, input, **kwargs) + + +@deprecated(since="2.4.0", + update_to="paddle.distributed.fleet.recompute", + level=1, + reason="Please use new recompute API(fleet.recompute) ") +def recompute(function, *args, **kwargs): + return fleet.recompute(function, *args, **kwargs) From 02d08075315c51292bd6fc8d88526f077fa7abc3 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Thu, 25 Aug 2022 07:07:23 +0000 Subject: [PATCH 13/31] update test. 
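
Fix the segment bounds in recompute_sequential (_run_func now covers the
inclusive range begin..end) and accept a plain list of layers as well as a
paddle.nn.Sequential. The call shape exercised by the updated unit tests
(sketch only; the layer sizes are illustrative):

    import paddle
    from paddle.distributed import fleet

    blocks = paddle.nn.Sequential(paddle.nn.Linear(10, 10), paddle.nn.ReLU(),
                                  paddle.nn.Linear(10, 10), paddle.nn.ReLU(),
                                  paddle.nn.Linear(10, 1))
    x = paddle.randn([4, 10])
    x.stop_gradient = False

    # checkpoint the Sequential in 2 segments; activations inside each
    # segment are freed in forward and recomputed during backward
    y = fleet.recompute_sequential(blocks, 2, x)
    y.mean().backward()
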
--- .../distributed/fleet/recompute/recompute.py | 19 +++++++++---------- .../tests/unittests/test_dygraph_recompute.py | 14 ++++++-------- .../test_dygraph_recompute_for_eager.py | 14 ++++++-------- 3 files changed, 21 insertions(+), 26 deletions(-) diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index c0ff3179ab74b..931202fde5600 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -508,25 +508,24 @@ def recompute_sequential(functions, segments, input, **kwargs): raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) - def _run_func(begin, end, functions): + def _run_func(begin, end, funcs): def do_run(input): - for i in range(begin, end): - input = functions[i](input) + for i in range(begin, end + 1): + input = funcs[i](input) return input return do_run - assert isinstance(functions, paddle.nn.Sequential), ( - "The functions must be of type paddle.nn.Sequential.") + if isinstance(functions, paddle.nn.Sequential): + functions = list(functions.children()) - functions = list(functions.children()) segment_size = len(functions) // segments - end = 0 + end = -1 for begin in range(0, segment_size * (segments - 1), segment_size): - end = begin + segment_size - input = recompute(_run_func(begin, end), + end = begin + segment_size - 1 + input = recompute(_run_func(begin, end, functions), input, preserve_rng_state=preserve) - return _run_func(end, len(functions))(input) + return _run_func(end + 1, len(functions) - 1, functions)(input) diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py index d42c3294805f2..97bf592991e28 100755 --- a/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_recompute.py @@ -118,6 +118,9 @@ def forward(self, inputs): def run_model(recompute_block=[], recompute_kwargs={}, + use_fleet=False, + use_fleet_sq=False, + segments=1, enable_autocast=False, pure_fp16=False): gen = paddle.seed(10) @@ -128,6 +131,9 @@ def run_model(recompute_block=[], batch_size, input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, + use_fleet=use_fleet, + use_fleet_sq=use_fleet_sq, + segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') optimizer = paddle.optimizer.SGD(learning_rate=0.01, @@ -223,14 +229,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - # recompute using fleet.recompute_sequential, segments=2 - loss, param, grad = run_model(recompute_block=[], - use_fleet_sq=True, - segments=2, - enable_autocast=enable_autocast, - pure_fp16=pure_fp16) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - def test_fc_net_with_dropout(self): self.test_base_case() diff --git a/python/paddle/fluid/tests/unittests/test_dygraph_recompute_for_eager.py b/python/paddle/fluid/tests/unittests/test_dygraph_recompute_for_eager.py index d986a430c658c..46aeb2e78314c 100755 --- a/python/paddle/fluid/tests/unittests/test_dygraph_recompute_for_eager.py +++ b/python/paddle/fluid/tests/unittests/test_dygraph_recompute_for_eager.py @@ -122,6 +122,9 @@ def forward(self, inputs): def run_model(recompute_block=[], recompute_kwargs={}, + use_fleet=False, + use_fleet_sq=False, + 
segments=1, enable_autocast=False, pure_fp16=False): gen = paddle.seed(10) @@ -132,6 +135,9 @@ def run_model(recompute_block=[], batch_size, input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, + use_fleet=use_fleet, + use_fleet_sq=use_fleet_sq, + segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') optimizer = paddle.optimizer.SGD(learning_rate=0.01, @@ -220,14 +226,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - # recompute_sequential with segments=2 using fleet - loss, param, grad = run_model(recompute_block=[], - use_fleet_sq=True, - segments=2, - enable_autocast=enable_autocast, - pure_fp16=pure_fp16) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - def test_fc_net_with_dropout(self): self.test_base_case() From ec8061acdfc2832074ab88e52b5c3f31b4d770c4 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Thu, 25 Aug 2022 08:01:08 +0000 Subject: [PATCH 14/31] add package in setup.py.in --- python/setup.py.in | 1 + 1 file changed, 1 insertion(+) diff --git a/python/setup.py.in b/python/setup.py.in index 66f0575284d8d..84bdbaad464c9 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -294,6 +294,7 @@ packages=['paddle', 'paddle.distributed.launch.plugins', 'paddle.distributed.launch.utils', 'paddle.distributed.fleet.base', + 'paddle.distributed.fleet.recompute', 'paddle.distributed.fleet.elastic', 'paddle.distributed.fleet.meta_optimizers', 'paddle.distributed.fleet.meta_optimizers.sharding', From c578851ee10a423e3991955783f9d7d8e460ff0d Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Mon, 29 Aug 2022 01:16:47 +0000 Subject: [PATCH 15/31] update first. 
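
Rename hybrid_recompute to recompute_hybrid and let fleet.recompute accept
extra keyword arguments instead of raising on them (they are now handed to
RecomputeFunction.apply together with preserve_rng_state). A sketch of the
renamed entry points after this patch (values are illustrative; the hybrid
variant still assumes an initialized hybrid-parallel fleet environment):

    import paddle
    from paddle.distributed import fleet

    block = paddle.nn.Linear(8, 8)
    x = paddle.randn([2, 8])
    x.stop_gradient = False

    # extra keyword arguments are no longer rejected by fleet.recompute
    y = fleet.recompute(block, x, preserve_rng_state=True)

    # formerly fleet.hybrid_recompute; offload / partition stay positional
    # (needs a hybrid-parallel setup, so only shown here, not executed)
    # y = fleet.recompute_hybrid(block, False, False, x)
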
--- python/paddle/distributed/fleet/__init__.py | 8 ++--- .../parallel_layers/pp_layers.py | 2 +- .../meta_parallel/sharding/sharding_stage3.py | 2 +- .../distributed/fleet/recompute/__init__.py | 4 +-- .../distributed/fleet/recompute/recompute.py | 35 ++++++++----------- ...ybrid_recompute.py => recompute_hybrid.py} | 28 ++++++++++----- .../distributed/models/moe/moe_layer.py | 2 +- 7 files changed, 43 insertions(+), 38 deletions(-) rename python/paddle/distributed/fleet/recompute/{hybrid_recompute.py => recompute_hybrid.py} (91%) diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 794198598f459..eeef71a96a7d6 100755 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -32,7 +32,7 @@ from .model import distributed_model from .optimizer import distributed_optimizer from .scaler import distributed_scaler -from .recompute import recompute, recompute_sequential, hybrid_recompute +from .recompute import recompute as R __all__ = [ #noqa "CommunicateTopology", "UtilBase", "HybridCommunicateGroup", @@ -91,6 +91,6 @@ shrink = fleet.shrink get_hybrid_communicate_group = fleet.get_hybrid_communicate_group distributed_scaler = distributed_scaler -recompute = recompute -recompute_sequential = recompute_sequential -hybrid_recompute = hybrid_recompute +recompute = R.recompute +recompute_sequential = R.recompute_sequential +recompute_hybrid = R.recompute_hybrid diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 67e6cf22a977e..37cde2fa502fb 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -401,7 +401,7 @@ def forward(self, input): input = (input, ) if self._need_recompute(funcs, input): - input = fleet.hybrid_recompute( + input = fleet.recompute_hybrid( self.forward_function(start_idx, end_idx), self._recompute_offload, self._recompute_partition, *input) diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py index 61b9e790474cd..c251a83306f4f 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py @@ -33,7 +33,7 @@ from paddle.distributed.collective import _get_global_group from .sharding_utils import Type, ShardingClipGrad, device_guard -from ...recompute.hybrid_recompute import _all_gather +from ...recompute.recompute_hybrid import _all_gather from ...utils.internal_storage import GradStorage # CUDA alignment 256 bytes diff --git a/python/paddle/distributed/fleet/recompute/__init__.py b/python/paddle/distributed/fleet/recompute/__init__.py index f9597f7a11f79..c5f99961e6f91 100644 --- a/python/paddle/distributed/fleet/recompute/__init__.py +++ b/python/paddle/distributed/fleet/recompute/__init__.py @@ -13,9 +13,9 @@ # limitations under the License. 
from .recompute import recompute, recompute_sequential, swith_rng_state_tracker, check_recompute_necessary, detach_variable -from .hybrid_recompute import hybrid_recompute +from .recompute_hybrid import recompute_hybrid __all__ = [ - "recompute", "recompute_sequential", "hybrid_recompute", + "recompute", "recompute_sequential", "recompute_hybrid", "swith_rng_state_tracker", "check_recompute_necessary", "detach_variable" ] diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index 931202fde5600..360d93484e738 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -210,12 +210,13 @@ def backward(ctx, *args): class RecomputeFunction(PyLayer): @staticmethod - def forward(ctx, run_function, preserve_rng_state, *args): + def forward(ctx, run_function, preserve_rng_state, *args, **kwargs): from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker # store for recomputing ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state + ctx.kwargs = kwargs # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input # the order of tensors in backward()'s output should be the same as tensors in forward()'s input @@ -300,7 +301,7 @@ def backward(ctx, *args): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) + outputs = ctx.run_function(*detached_inputs, ctx.kwargs) else: with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, @@ -308,7 +309,7 @@ def backward(ctx, *args): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) + outputs = ctx.run_function(*detached_inputs, ctx.kwargs) if isinstance(outputs, (core.VarBase, core.eager.Tensor)): outputs = (outputs, ) @@ -355,7 +356,7 @@ def recompute(function, *args, **kwargs): recompute intermediate activations to save then memory. Parameters: - function(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model + function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model whose intermediate activations will be released to save memory in forward stage and will be recomputed in backward stage for gradient calculation. *args(Tensor): inputs to the function. @@ -469,17 +470,14 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): """ # Hack to mix *args with **kwargs in a python 2.7-compliant way preserve = kwargs.pop('preserve_rng_state', True) - if kwargs: - raise ValueError("Unexpected keyword arguments: " + - ",".join(arg for arg in kwargs)) if framework._dygraph_tracer()._has_grad: check_recompute_necessary(args) - return RecomputeFunction.apply(function, preserve, *args) + return RecomputeFunction.apply(function, preserve, *args, **kwargs) -def recompute_sequential(functions, segments, input, **kwargs): +def recompute_sequential(functions, *args, **kwargs): """ recompute intermediate activations to save then memory for 'Sequential' models. 
@@ -487,12 +485,12 @@ def recompute_sequential(functions, segments, input, **kwargs): functions(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model whose intermediate activations will be released to save memory in forward stage and will be recomputed in backward stage for gradient calculation. - segments(int): Number of chunks to create in the model - input(Tensor): inputs to the function. - **kwargs(Dict): Kwargs should only contain the key-value pair of preserve_rng_state, which is used to + *args(Tensor): inputs to the function. + **kwargs(Dict): Kwargs should contain the key-value pair of preserve_rng_state, which is used to indicate whether to save the forward rng. If it is True, then the last forward rng value will be restored when the forward recalculation of backpropagation is performed. The default - preserve_rng_state is True. + preserve_rng_state is True. and it contains the key-value pair of __segments__, which is on behalf of the + Number of chunks to create in the model. Returns: Output of function on args. @@ -503,10 +501,7 @@ def recompute_sequential(functions, segments, input, **kwargs): model = paddle.nn.Sequential(...) input = recompute_sequential(model, segments, input) """ - preserve = kwargs.pop('preserve_rng_state', True) - if kwargs: - raise ValueError("Unexpected keyword arguments: " + - ",".join(arg for arg in kwargs)) + segments = kwargs.pop('__segments__', 1) def _run_func(begin, end, funcs): @@ -525,7 +520,5 @@ def do_run(input): end = -1 for begin in range(0, segment_size * (segments - 1), segment_size): end = begin + segment_size - 1 - input = recompute(_run_func(begin, end, functions), - input, - preserve_rng_state=preserve) - return _run_func(end + 1, len(functions) - 1, functions)(input) + args = recompute(_run_func(begin, end, functions), *args, **kwargs) + return _run_func(end + 1, len(functions) - 1, functions)(args) diff --git a/python/paddle/distributed/fleet/recompute/hybrid_recompute.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py similarity index 91% rename from python/paddle/distributed/fleet/recompute/hybrid_recompute.py rename to python/paddle/distributed/fleet/recompute/recompute_hybrid.py index 3ab796ba817ff..7af2f28de2dfb 100644 --- a/python/paddle/distributed/fleet/recompute/hybrid_recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -25,7 +25,7 @@ from .recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker __all__ = [ - "hybrid_recompute", "is_float_tensor", "get_tensor_dtype", + "recompute_hybrid", "is_float_tensor", "get_tensor_dtype", "paddle_2_number", "get_tensor_bytes" ] @@ -146,12 +146,15 @@ class _HPRecomputeFunction(PyLayer): """ @staticmethod - def forward(ctx, run_function, all_outputs, offload, partition, *args): + def forward(ctx, run_function, all_outputs, offload, partition, *args, + **kwargs): check_recompute_necessary(args) # store for recomputing ctx.run_function = run_function + ctx.kwargs = kwargs + # store the rng states ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( @@ -185,7 +188,7 @@ def forward(ctx, run_function, all_outputs, offload, partition, *args): ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): - outputs = run_function(*args) + outputs = run_function(*args, **kwargs) for i, arg in enumerate(args): if paddle.is_tensor(arg): @@ -243,7 +246,7 @@ def backward(ctx, *args): 
custom_black_list=ctx.amp_black_list, level=ctx.amp_level): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) + outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) if isinstance(outputs, (core.VarBase, core.eager.Tensor)): outputs = (outputs, ) @@ -272,7 +275,7 @@ def backward(ctx, *args): return grads -def hybrid_recompute(function, offload, partition, *args): +def recompute_hybrid(function, *args, **kwargs): """ # NODTE(shenliang03)The current hybrid parallel recompute has limitations. # It cannot handle the following situations: @@ -284,16 +287,25 @@ def hybrid_recompute(function, offload, partition, *args): function(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model whose intermediate activations will be released to save memory in forward stage and will be recomputed in backward stage for gradient calculation. - offload(bool): whether to offload checkpoint. - partition: whether to split activation into each rank in model group. *args(Tensor): inputs to the function. + + **kwargs(Dict): Kwargs should contain the key-value pair of preserve_rng_state, which is used to + indicate whether to save the forward rng. If it is True, then the last forward rng value will be + restored when the forward recalculation of backpropagation is performed. The default + preserve_rng_state is True. and it contains the key-value pair of __offload__ and __partition__, they are on behalf of whether to offload + to cpu and whether to split activation. + Returns: Output of function on args. """ + offload = kwargs.pop('__offload__', True) + partition = kwargs.pop('__partition__', True) + all_outputs = [] - _HPRecomputeFunction.apply(function, all_outputs, offload, partition, *args) + _HPRecomputeFunction.apply(function, all_outputs, offload, partition, *args, + **kwargs) if len(all_outputs) == 1: return all_outputs[0] diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index c6a564e2358df..e8892ce9df6d2 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -426,7 +426,7 @@ def experts_fwd(x, fwd_expert_count, experts): if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: - x = fleet.hybrid_recompute(experts_fwd, self.recompute_offload, + x = fleet.recompute_hybrid(experts_fwd, self.recompute_offload, self.recompute_partition, x, fwd_expert_count.numpy(), self.experts) From abd08527fa0fdde8c074759b0dfbd52f3048e34a Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Mon, 29 Aug 2022 01:21:16 +0000 Subject: [PATCH 16/31] update recompute_hybrid.py --- python/paddle/distributed/fleet/recompute/recompute_hybrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py index 1b452135f2106..8491cb6dc07d0 100644 --- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -284,7 +284,7 @@ def recompute_hybrid(function, *args, **kwargs): # 3. 
Here, we only use float dtype to distinguish whether a gradient is needed in output tensor Parameters: - function(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model + function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model whose intermediate activations will be released to save memory in forward stage and will be recomputed in backward stage for gradient calculation. *args(Tensor): inputs to the function. From 9e5788dce77010f7a2dc5eb2ad0b4564deef4a71 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Mon, 29 Aug 2022 03:25:17 +0000 Subject: [PATCH 17/31] update input of **kwargs. --- .../fleet/meta_parallel/parallel_layers/pp_layers.py | 7 ++++--- .../unittests/collective/fleet/test_dygraph_recompute.py | 5 +++-- .../collective/fleet/test_dygraph_recompute_for_eager.py | 5 +++-- .../paddle/incubate/distributed/models/moe/moe_layer.py | 9 ++++++--- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 0dd200e2e8a89..980062764cb40 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -557,9 +557,10 @@ def forward(self, input, chunk_id=None): if self._need_recompute(funcs, input): input = fleet.recompute_hybrid( - self.forward_function(start_idx, - end_idx), self._recompute_offload, - self._recompute_partition, *input) + self.forward_function(start_idx, end_idx), + *input, + __offload__=self._recompute_offload, + __partition__=self._recompute_partition) else: input = self.forward_function(start_idx, end_idx)(*input) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py index 97bf592991e28..ff8c80cf64913 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py @@ -79,8 +79,9 @@ def __init__(self, def forward(self, inputs): if self.use_fleet_sq: - return fleet.recompute_sequential(self.runfuncs, self.segments, - inputs) + return fleet.recompute_sequential(self.runfuncs, + inputs, + __segments__=self.segments) if 0 in self.recompute_blocks: recompute_func = fleet.recompute if self.use_fleet else recompute diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py index 46aeb2e78314c..f724749194b03 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py @@ -83,8 +83,9 @@ def __init__(self, def forward(self, inputs): if self.use_fleet_sq: - return fleet.recompute_sequential(self.runfuncs, self.segments, - inputs) + return fleet.recompute_sequential(self.runfuncs, + inputs, + __segments__=self.segments) if 0 in self.recompute_blocks: recompute_func = fleet.recompute if self.use_fleet else recompute diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index ece7fc7548d81..0fa8859cd83ae 100644 --- 
a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -426,9 +426,12 @@ def experts_fwd(x, fwd_expert_count, experts): if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: - x = fleet.recompute_hybrid(experts_fwd, self.recompute_offload, - self.recompute_partition, x, - fwd_expert_count.numpy(), self.experts) + x = fleet.recompute_hybrid(experts_fwd, + x, + fwd_expert_count.numpy(), + self.experts, + __offload__=self.recompute_offload, + __partition__=self.recompute_partition) out_batch_size = inp.shape[0] if len(gate.shape) == 2: From 153cb997acee0ca63dd3b6fa82e51fefd1c35d3e Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Mon, 29 Aug 2022 11:47:57 +0000 Subject: [PATCH 18/31] update. --- python/paddle/distributed/fleet/__init__.py | 8 +- .../parallel_layers/pp_layers.py | 10 +- .../pp_utils/p2p_communication.py | 4 +- .../fleet/meta_parallel/pp_utils/utils.py | 98 ++++++++++++ .../meta_parallel/sharding/sharding_stage3.py | 2 +- .../distributed/fleet/recompute/recompute.py | 33 +++-- .../fleet/recompute/recompute_hybrid.py | 140 +++++------------- .../distributed/fleet/utils/__init__.py | 13 +- .../fleet/test_dygraph_recompute.py | 7 +- .../fleet/test_dygraph_recompute_for_eager.py | 7 +- .../distributed/models/moe/moe_layer.py | 12 +- 11 files changed, 176 insertions(+), 158 deletions(-) create mode 100644 python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index eeef71a96a7d6..625913c647c0b 100755 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -32,7 +32,7 @@ from .model import distributed_model from .optimizer import distributed_optimizer from .scaler import distributed_scaler -from .recompute import recompute as R +import paddle.distributed.fleet.recompute as Re __all__ = [ #noqa "CommunicateTopology", "UtilBase", "HybridCommunicateGroup", @@ -91,6 +91,6 @@ shrink = fleet.shrink get_hybrid_communicate_group = fleet.get_hybrid_communicate_group distributed_scaler = distributed_scaler -recompute = R.recompute -recompute_sequential = R.recompute_sequential -recompute_hybrid = R.recompute_hybrid +recompute = Re.recompute +recompute_sequential = Re.recompute_sequential +recompute_hybrid = Re.recompute_hybrid diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 980062764cb40..af1e98be5c81d 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -557,10 +557,12 @@ def forward(self, input, chunk_id=None): if self._need_recompute(funcs, input): input = fleet.recompute_hybrid( - self.forward_function(start_idx, end_idx), - *input, - __offload__=self._recompute_offload, - __partition__=self._recompute_partition) + { + "mp_group": + fleet.fleet._hcg.get_model_parallel_group(), + "offload": self._recompute_offload, + "partition": self._recompute_partition + }, self.forward_function(start_idx, end_idx), *input) else: input = self.forward_function(start_idx, end_idx)(*input) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index 
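A minimal sketch of the dict-style calling convention this patch moves to (layer shapes and the hcg lookup are illustrative; the hybrid call additionally assumes fleet.init has been run with a hybrid strategy so that a model-parallel group exists):

    import paddle
    import paddle.distributed.fleet as fleet

    blocks = paddle.nn.Sequential(
        paddle.nn.Linear(8, 8),
        paddle.nn.ReLU(),
        paddle.nn.Linear(8, 1),
    )
    x = paddle.randn([2, 8])
    x.stop_gradient = False

    # 'segments' in the leading ctx dict replaces the old __segments__ keyword
    y = fleet.recompute_sequential({"segments": 2}, blocks, x)

    # hybrid-parallel variant: the ctx dict carries the model-parallel group
    # and the offload/partition switches (requires fleet.init with mp_degree > 1)
    # ctx = {
    #     "mp_group": fleet.fleet._hcg.get_model_parallel_group(),
    #     "offload": False,
    #     "partition": False,
    # }
    # y = fleet.recompute_hybrid(ctx, blocks, x)

Packing the recompute knobs into a leading dict keeps *args and **kwargs free for the wrapped function's own inputs, which is what the reworked callers and tests in this patch rely on.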
a9720bb128134..6af444a12dd8c 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -18,6 +18,7 @@ from paddle import _C_ops, _legacy_C_ops import paddle.fluid.core as core from paddle.fluid.framework import _in_legacy_dygraph, _non_static_mode, in_dygraph_mode +from .utils import paddle_2_number, paddle_2_number, number_2_dtype _hcg = None _use_cache = False @@ -102,7 +103,6 @@ def recv_meta(self, group): self.recv_stop_gradient = tuple(stop_grads) def _send_dims_shape_dtype(self, tensor, group): - from ...recompute.hybrid_recompute import paddle_2_number # send len(shape) dims = paddle.to_tensor(len(tensor.shape)) dst_rank = group.ranks[1] @@ -143,7 +143,6 @@ def send_meta(self, tensor, group): self._send_dims_shape_dtype(d, group=group) def set_send_message(self, tensor): - from ...recompute.hybrid_recompute import paddle_2_number if isinstance(tensor, (paddle.Tensor, core.eager.Tensor)): self.send_shape_message = tensor.shape self.send_dtype_message = paddle_2_number(tensor.dtype) @@ -279,7 +278,6 @@ def allgather_partial(tensor, def _p2p_helper(tensor_send_next, tensor_send_prev, recv_prev, recv_next): - from ...recompute.hybrid_recompute import number_2_dtype global _hcg tensor_recv_prev = None diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py new file mode 100644 index 0000000000000..82c08e37e7e52 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -0,0 +1,98 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +from paddle.fluid import core +from paddle import _C_ops, _legacy_C_ops + +__all__ = [ + "is_float_tensor", "get_tensor_dtype", "paddle_2_number", "number_2_dtype", + "get_tensor_bytes", "_all_gather" +] + +FLOAT_TYPE_DICT = { + paddle.float16: "float16", + paddle.float32: "float32", + paddle.float64: "float64", +} + +PADDLE_TO_NUMBER = { + paddle.float16: 0, + paddle.float32: 1, + paddle.float64: 2, + paddle.int32: 3, + paddle.int64: 4 +} + +NUMBER_TO_DTYPE = { + 0: "float16", + 1: "float32", + 2: "float64", + 3: "int32", + 4: "int64" +} + + +def is_float_tensor(tensor): + """Is a float tensor""" + return tensor.dtype in FLOAT_TYPE_DICT.keys() + + +def get_tensor_dtype(dtype): + assert dtype in FLOAT_TYPE_DICT.keys() + return FLOAT_TYPE_DICT[dtype] + + +def paddle_2_number(dtype): + assert dtype in PADDLE_TO_NUMBER.keys() + return PADDLE_TO_NUMBER[dtype] + + +def number_2_dtype(number): + assert number in NUMBER_TO_DTYPE.keys() + return NUMBER_TO_DTYPE[number] + + +def get_tensor_bytes(tensor): + """Get the bytes a tensor occupied.""" + elem_size = None + if tensor.dtype == paddle.float32: + elem_size = 4 + elif tensor.dtype == paddle.float64: + elem_size = 8 + elif tensor.dtype == paddle.int64: + elem_size = 8 + elif tensor.dtype == paddle.int32: + elem_size = 4 + elif tensor.dtype == paddle.float16: + elem_size = 2 + elif tensor.dtype == paddle.int8: + elem_size = 1 + else: + raise ValueError("unknown data type: {}".format(tensor.dtype)) + return tensor.numel() * elem_size + + +def _all_gather(tensor, group=None, use_calc_stream=True): + """ + The main difference with paddle.distributed.all_gather: + no need to pass in tensor_list, the returned tensor is spliced + """ + if group is not None and not group.is_member(): + return + ring_id = 0 if group is None else group.id + nranks = paddle.distributed.collective._get_global_group( + ).nranks if group is None else group.nranks + return _legacy_C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream, + 'ring_id', ring_id, 'nranks', nranks) diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py index c251a83306f4f..67d48c8abba1b 100644 --- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py +++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage3.py @@ -33,7 +33,7 @@ from paddle.distributed.collective import _get_global_group from .sharding_utils import Type, ShardingClipGrad, device_guard -from ...recompute.recompute_hybrid import _all_gather +from ..pp_utils.utils import _all_gather from ...utils.internal_storage import GradStorage # CUDA alignment 256 bytes diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index 360d93484e738..6474acdbf9c5a 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -269,7 +269,7 @@ def forward(ctx, run_function, preserve_rng_state, *args, **kwargs): ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): - outputs = run_function(*args) + outputs = run_function(*args, **kwargs) return outputs @staticmethod @@ -301,7 +301,8 @@ def backward(ctx, *args): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs, ctx.kwargs) + outputs = ctx.run_function(*detached_inputs, + 
**ctx.kwargs) else: with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, @@ -309,7 +310,7 @@ def backward(ctx, *args): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs, ctx.kwargs) + outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) if isinstance(outputs, (core.VarBase, core.eager.Tensor)): outputs = (outputs, ) @@ -477,31 +478,32 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): return RecomputeFunction.apply(function, preserve, *args, **kwargs) -def recompute_sequential(functions, *args, **kwargs): +def recompute_sequential(ctx, functions, *args, **kwargs): """ recompute intermediate activations to save then memory for 'Sequential' models. Parameters: + ctx(dict): include 'segments' and 'preserve_rng_state' keys, the key 'segments' (int, default 1), represents the number of chunks to create in the model, + the key 'preserve_rng_state' (bool, optional, default=True) indicate whether to save the forward rng. If it is True, then the last forward rng value will be + restored when the forward recalculation of backpropagation is performed. and some keys such as 'mp_group', 'offload' and 'partition' are invalid here, + they are useful in 'recompute_hybrid' API. functions(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model whose intermediate activations will be released to save memory in forward stage and will be recomputed in backward stage for gradient calculation. - *args(Tensor): inputs to the function. - **kwargs(Dict): Kwargs should contain the key-value pair of preserve_rng_state, which is used to - indicate whether to save the forward rng. If it is True, then the last forward rng value will be - restored when the forward recalculation of backpropagation is performed. The default - preserve_rng_state is True. and it contains the key-value pair of __segments__, which is on behalf of the - Number of chunks to create in the model. + *args(Tensor): inputs(tuple) to the function. + **kwargs(Dict): inputs(dict) to the function. Returns: - Output of function on args. + Output of function on args and kwargs. Examples: .. code-block:: python model = paddle.nn.Sequential(...) 
- input = recompute_sequential(model, segments, input) + input = recompute_sequential({'segments' : 1}, model, input) """ - segments = kwargs.pop('__segments__', 1) + segments = ctx.get('segments', 1) + preserve_rng_state = ctx.get('preserve_rng_state', True) def _run_func(begin, end, funcs): @@ -520,5 +522,8 @@ def do_run(input): end = -1 for begin in range(0, segment_size * (segments - 1), segment_size): end = begin + segment_size - 1 - args = recompute(_run_func(begin, end, functions), *args, **kwargs) + args = recompute(_run_func(begin, end, functions), + *args, + preserve_rng_state=preserve_rng_state, + **kwargs) return _run_func(end + 1, len(functions) - 1, functions)(args) diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py index 8491cb6dc07d0..58ee35657557b 100644 --- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -23,93 +23,15 @@ from paddle.fluid.framework import in_dygraph_mode from paddle.distributed import fleet from .recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker +from ..meta_parallel.pp_utils import utils -__all__ = [ - "recompute_hybrid", "is_float_tensor", "get_tensor_dtype", - "paddle_2_number", "get_tensor_bytes" -] - -FLOAT_TYPE_DICT = { - paddle.float16: "float16", - paddle.float32: "float32", - paddle.float64: "float64", -} - -PADDLE_TO_NUMBER = { - paddle.float16: 0, - paddle.float32: 1, - paddle.float64: 2, - paddle.int32: 3, - paddle.int64: 4 -} - -NUMBER_TO_DTYPE = { - 0: "float16", - 1: "float32", - 2: "float64", - 3: "int32", - 4: "int64" -} - - -def is_float_tensor(tensor): - """Is a float tensor""" - return tensor.dtype in FLOAT_TYPE_DICT.keys() - - -def get_tensor_dtype(dtype): - assert dtype in FLOAT_TYPE_DICT.keys() - return FLOAT_TYPE_DICT[dtype] - - -def paddle_2_number(dtype): - assert dtype in PADDLE_TO_NUMBER.keys() - return PADDLE_TO_NUMBER[dtype] - - -def number_2_dtype(number): - assert number in NUMBER_TO_DTYPE.keys() - return NUMBER_TO_DTYPE[number] - - -def get_tensor_bytes(tensor): - """Get the bytes a tensor occupied.""" - elem_size = None - if tensor.dtype == paddle.float32: - elem_size = 4 - elif tensor.dtype == paddle.float64: - elem_size = 8 - elif tensor.dtype == paddle.int64: - elem_size = 8 - elif tensor.dtype == paddle.int32: - elem_size = 4 - elif tensor.dtype == paddle.float16: - elem_size = 2 - elif tensor.dtype == paddle.int8: - elem_size = 1 - else: - raise ValueError("unknown data type: {}".format(tensor.dtype)) - return tensor.numel() * elem_size - - -def _all_gather(tensor, group=None, use_calc_stream=True): - """ - The main difference with paddle.distributed.all_gather: - no need to pass in tensor_list, the returned tensor is spliced - """ - if group is not None and not group.is_member(): - return - ring_id = 0 if group is None else group.id - nranks = paddle.distributed.collective._get_global_group( - ).nranks if group is None else group.nranks - return _legacy_C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream, - 'ring_id', ring_id, 'nranks', nranks) +__all__ = ["recompute_hybrid"] -def _split_activation(tensor): +def _split_activation(tensor, mp_group): - mp_degree = fleet.fleet._hcg.get_model_parallel_world_size() - mp_rank = fleet.fleet._hcg.get_model_parallel_rank() + mp_degree = mp_group.nranks + mp_rank = mp_group.rank if mp_degree < 2: return tensor @@ -127,13 +49,12 @@ def _split_activation(tensor): 
return data[start:end] -def _merge_activation(tensor): - mp_degree = fleet.fleet._hcg.get_model_parallel_world_size() - mp_rank = fleet.fleet._hcg.get_model_parallel_rank() - mp_group = fleet.fleet._hcg.get_model_parallel_group() +def _merge_activation(tensor, mp_group): + mp_degree = mp_degree.nranks + mp_rank = mp_degree.rank if mp_degree < 2: return tensor - return _all_gather(tensor, group=mp_group) + return utils._all_gather(tensor, group=mp_group) class _HPRecomputeFunction(PyLayer): @@ -146,8 +67,8 @@ class _HPRecomputeFunction(PyLayer): """ @staticmethod - def forward(ctx, run_function, all_outputs, offload, partition, *args, - **kwargs): + def forward(ctx, run_function, all_outputs, mp_group, offload, partition, + preserve_rng_state, *args, **kwargs): check_recompute_necessary(args) # store for recomputing @@ -156,11 +77,13 @@ def forward(ctx, run_function, all_outputs, offload, partition, *args, ctx.kwargs = kwargs # store the rng states + assert preserve_rng_state, 'preserve_rng_state must be True in recompute_hybrid.' ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( ).get_states_tracker() # save config info + ctx.mp_group = mp_group ctx.offload = offload ctx.partition = partition @@ -195,7 +118,8 @@ def forward(ctx, run_function, all_outputs, offload, partition, *args, state = arg.stop_gradient if partition: ctx.tensor_shapes.append(arg.shape) - partition = _split_activation(arg.detach()).clone() + partition = _split_activation(arg.detach(), + mp_group).clone() # TODO(shenliang03) not use calculate stream to D2H to speed arg = partition.cpu() if offload else partition else: @@ -229,8 +153,9 @@ def backward(ctx, *args): for i, idx in enumerate(tensor_indices): if ctx.partition: state = tensors[i].stop_gradient - tensors[i] = _merge_activation( - tensors[i]).detach().reshape_(tensor_shapes[i]) + tensors[i] = _merge_activation(tensors[i], + mp_group).detach().reshape_( + tensor_shapes[i]) tensors[i].stop_gradient = state inputs[idx] = tensors[i].cuda( device_id) if ctx.offload else tensors[i] @@ -275,7 +200,7 @@ def backward(ctx, *args): return grads -def recompute_hybrid(function, *args, **kwargs): +def recompute_hybrid(ctx, function, *args, **kwargs): """ # NODTE(shenliang03)The current hybrid parallel recompute has limitations. # It cannot handle the following situations: @@ -284,34 +209,37 @@ def recompute_hybrid(function, *args, **kwargs): # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor Parameters: + ctx(dict): include 'mp_group', 'offload', 'partition' and 'preserve_rng_state' keys. The key 'mp_group' (Group) represents the group in which the activations are split. + The key 'offload' (bool, optional, default=False) represents whether to offload the checkpointed activations to CPU. The key 'partition' (bool, optional, default=False) + represents whether to split the activations across the mp_group. The key 'preserve_rng_state' (bool, optional, default=True) indicates whether to save the forward rng. + If it is True, the last forward rng state will be restored when the recomputation in backpropagation is performed. Keys such as 'segments' + are invalid here; they are only used by the 'recompute_sequential' API. function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model whose intermediate activations will be released to save memory in forward stage and will be recomputed in backward stage for gradient calculation.

- *args(Tensor): inputs to the function. + *args(Tensor): inputs(tuple) to the function. - **kwargs(Dict): Kwargs should contain the key-value pair of preserve_rng_state, which is used to - indicate whether to save the forward rng. If it is True, then the last forward rng value will be - restored when the forward recalculation of backpropagation is performed. The default - preserve_rng_state is True. and it contains the key-value pair of __offload__ and __partition__, they are on behalf of whether to offload - to cpu and whether to split activation. + **kwargs(Dict): inputs(dict) to the function. Returns: - Output of function on args. + Output of function on args and kwargs. """ + assert "mp_group" in ctx.keys(), "ctx must contains mp_group." - offload = kwargs.pop('__offload__', True) - partition = kwargs.pop('__partition__', True) + offload = ctx.get('offload', False) + partition = ctx.get('partition', False) + preserve_rng_state = ctx.get('preserve_rng_state', True) all_outputs = [] - _HPRecomputeFunction.apply(function, all_outputs, offload, partition, *args, - **kwargs) + _HPRecomputeFunction.apply(function, all_outputs, mp_group, offload, + partition, preserve_rng_state, *args, **kwargs) if len(all_outputs) == 1: return all_outputs[0] else: for output in all_outputs: - if paddle.is_tensor(output) and not is_float_tensor(output): + if paddle.is_tensor(output) and not utils.is_float_tensor(output): output.stop_gradient = True return tuple(all_outputs) diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 5f6f903f37477..db44048ac0042 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -22,21 +22,10 @@ from . import hybrid_parallel_util # noqa: F401 __all__ = [ #noqa - "LocalFS", "recompute", "recompute_sequential", "DistributedInfer", - "HDFSClient" + "LocalFS", "recompute", "DistributedInfer", "HDFSClient" ] -@deprecated( - since="2.4.0", - update_to="paddle.distributed.fleet.recompute_sequential", - level=1, - reason="Please use new recompute_sequential API(fleet.recompute_sequential) " -) -def recompute_sequential(functions, segments, input, **kwargs): - return fleet.recompute_sequential(functions, segments, input, **kwargs) - - @deprecated(since="2.4.0", update_to="paddle.distributed.fleet.recompute", level=1, diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py index ff8c80cf64913..db2d229e99b9e 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py @@ -79,9 +79,8 @@ def __init__(self, def forward(self, inputs): if self.use_fleet_sq: - return fleet.recompute_sequential(self.runfuncs, - inputs, - __segments__=self.segments) + return fleet.recompute_sequential({"segments": self.segments}, + self.runfuncs, inputs) if 0 in self.recompute_blocks: recompute_func = fleet.recompute if self.use_fleet else recompute @@ -242,7 +241,7 @@ def test_fc_net_with_fp16(self): def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], recompute_kwargs=kwargs) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py 
b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py index f724749194b03..b6da716b1a859 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py @@ -83,9 +83,8 @@ def __init__(self, def forward(self, inputs): if self.use_fleet_sq: - return fleet.recompute_sequential(self.runfuncs, - inputs, - __segments__=self.segments) + return fleet.recompute_sequential({"segments": self.segments}, + self.runfuncs, inputs) if 0 in self.recompute_blocks: recompute_func = fleet.recompute if self.use_fleet else recompute @@ -245,7 +244,7 @@ def test_fc_net_with_fp16(self): def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], recompute_kwargs=kwargs) diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index 0fa8859cd83ae..67909d609af91 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -426,12 +426,12 @@ def experts_fwd(x, fwd_expert_count, experts): if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: - x = fleet.recompute_hybrid(experts_fwd, - x, - fwd_expert_count.numpy(), - self.experts, - __offload__=self.recompute_offload, - __partition__=self.recompute_partition) + x = fleet.recompute_hybrid( + { + "mp_group": fleet.fleet._hcg.get_model_parallel_group(), + "offload": self.recompute_offload, + "partition": self.recompute_partition + }, experts_fwd, x, fwd_expert_count.numpy(), self.experts) out_batch_size = inp.shape[0] if len(gate.shape) == 2: From 107b6a11022eeb5a9a136bffa19fe1154ca68e99 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Mon, 29 Aug 2022 12:06:36 +0000 Subject: [PATCH 19/31] update. --- python/paddle/distributed/fleet/recompute/__init__.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/python/paddle/distributed/fleet/recompute/__init__.py b/python/paddle/distributed/fleet/recompute/__init__.py index c5f99961e6f91..e5e732560a55d 100644 --- a/python/paddle/distributed/fleet/recompute/__init__.py +++ b/python/paddle/distributed/fleet/recompute/__init__.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .recompute import recompute, recompute_sequential, swith_rng_state_tracker, check_recompute_necessary, detach_variable +from .recompute import recompute, recompute_sequential from .recompute_hybrid import recompute_hybrid -__all__ = [ - "recompute", "recompute_sequential", "recompute_hybrid", - "swith_rng_state_tracker", "check_recompute_necessary", "detach_variable" -] +__all__ = ["recompute", "recompute_sequential", "recompute_hybrid"] From db278831d1c8bb0aa6f685f32245882cea9af48a Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 30 Aug 2022 02:02:44 +0000 Subject: [PATCH 20/31] update recompute_hybrid.py --- python/paddle/distributed/fleet/recompute/recompute_hybrid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py index 58ee35657557b..fa0f49c29c5f0 100644 --- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -225,7 +225,8 @@ def recompute_hybrid(ctx, function, *args, **kwargs): Output of function on args and kwargs. """ - assert "mp_group" in ctx.keys(), "ctx must contains mp_group." + mp_group = ctx.get('mp_group', None) + assert mp_group is not None, "ctx must contains mp_group and mp_group can not be None." offload = ctx.get('offload', False) partition = ctx.get('partition', False) From 69e8ebb3815b8ec2e173a8cead91708ceef34f12 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 30 Aug 2022 11:06:54 +0000 Subject: [PATCH 21/31] update test. --- .../fleet/recompute/recompute_hybrid.py | 19 +- .../fleet/test_dygraph_recompute.py | 75 +++--- .../fleet/test_dygraph_recompute_for_eager.py | 75 +++--- .../unittests/dygraph_recompute_hybrid.py | 220 ++++++++++++++++++ .../tests/unittests/test_pipeline_parallel.py | 6 + 5 files changed, 325 insertions(+), 70 deletions(-) create mode 100755 python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py index fa0f49c29c5f0..8ddb20842f862 100644 --- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -50,11 +50,18 @@ def _split_activation(tensor, mp_group): def _merge_activation(tensor, mp_group): - mp_degree = mp_degree.nranks - mp_rank = mp_degree.rank + mp_degree = mp_group.nranks + mp_rank = mp_group.rank if mp_degree < 2: return tensor - return utils._all_gather(tensor, group=mp_group) + + # adapt to new dygraph + tensor_shape = list(tensor.shape) + tensor_shape[0] *= mp_group.nranks + out = paddle.empty(tensor_shape, tensor.dtype) + task = mp_group.process_group.all_gather(tensor.cuda(), out) + task.wait() + return out class _HPRecomputeFunction(PyLayer): @@ -153,9 +160,9 @@ def backward(ctx, *args): for i, idx in enumerate(tensor_indices): if ctx.partition: state = tensors[i].stop_gradient - tensors[i] = _merge_activation(tensors[i], - mp_group).detach().reshape_( - tensor_shapes[i]) + tensors[i] = _merge_activation( + tensors[i], + ctx.mp_group).detach().reshape_(tensor_shapes[i]) tensors[i].stop_gradient = state inputs[idx] = tensors[i].cuda( device_id) if ctx.offload else tensors[i] diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py index 
db2d229e99b9e..f9f2909ea24be 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py @@ -19,7 +19,6 @@ import paddle from paddle.autograd import PyLayer -from paddle.distributed.fleet.utils import recompute import random from paddle.distributed import fleet @@ -57,12 +56,14 @@ def __init__(self, use_fleet=False, use_fleet_sq=False, segments=1, + use_raw_recompute=False, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs self.use_fleet = use_fleet self.use_fleet_sq = use_fleet_sq + self.use_raw_recompute = use_raw_recompute self.segments = segments self.runfunc0 = get_fc_block(0, input_size, is_last=False) @@ -71,47 +72,40 @@ def __init__(self, self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) - if self.use_fleet_sq: + if self.use_fleet_sq and not use_raw_recompute: self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4) + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] + + # default segments = 2 + if use_raw_recompute: + self.layers = [ + paddle.nn.Sequential(self.runfunc0, self.runfunc1), + paddle.nn.Sequential(self.runfunc2, self.runfunc3, + self.runfunc4) + ] + def forward(self, inputs): - if self.use_fleet_sq: + if self.use_fleet_sq and not self.use_raw_recompute: return fleet.recompute_sequential({"segments": self.segments}, self.runfuncs, inputs) - if 0 in self.recompute_blocks: - recompute_func = fleet.recompute if self.use_fleet else recompute - inputs = recompute_func(self.runfunc0, inputs) - else: - inputs = self.runfunc0(inputs) - - if 1 in self.recompute_blocks: - recompute_func = fleet.recompute if self.use_fleet else recompute - inputs = recompute_func(self.runfunc1, inputs) - else: - inputs = self.runfunc1(inputs) + if self.use_raw_recompute: + inputs = fleet.recompute(self.layers[0], inputs) + return self.layers[1](inputs) - if 2 in self.recompute_blocks: - recompute_func = fleet.recompute if self.use_fleet else recompute - inputs = recompute_func(self.runfunc2, inputs, - **self.recompute_kwargs) - else: - inputs = self.runfunc2(inputs) - - if 3 in self.recompute_blocks: - recompute_func = fleet.recompute if self.use_fleet else recompute - inputs = recompute_func(self.runfunc3, inputs) - else: - inputs = self.runfunc3(inputs) - - if 4 in self.recompute_blocks: - recompute_func = fleet.recompute if self.use_fleet else recompute - inputs = recompute_func(self.runfunc4, inputs) - else: - inputs = self.runfunc4(inputs) + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = fleet.recompute(self.layers[i], inputs, + **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) return inputs @@ -120,6 +114,7 @@ def run_model(recompute_block=[], recompute_kwargs={}, use_fleet=False, use_fleet_sq=False, + use_raw_recompute=False, segments=1, enable_autocast=False, pure_fp16=False): @@ -133,6 +128,7 @@ def run_model(recompute_block=[], recompute_blocks=recompute_block, use_fleet=use_fleet, use_fleet_sq=use_fleet_sq, + use_raw_recompute=use_raw_recompute, segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') @@ -229,6 +225,21 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): pure_fp16=pure_fp16) 
check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # with base recompute, and segments=2 + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + use_raw_recompute=True, + pure_fp16=pure_fp16) + + # recompute using fleet.recompute_sequential, segments=2 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py index b6da716b1a859..2ff6b4db58b20 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py @@ -23,7 +23,6 @@ import paddle from paddle.autograd import PyLayer -from paddle.distributed.fleet.utils import recompute from paddle.distributed import fleet import random @@ -61,12 +60,14 @@ def __init__(self, use_fleet=False, use_fleet_sq=False, segments=1, + use_raw_recompute=False, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs self.use_fleet = use_fleet self.use_fleet_sq = use_fleet_sq + self.use_raw_recompute = use_raw_recompute self.segments = segments self.runfunc0 = get_fc_block(0, input_size, is_last=False) @@ -75,47 +76,40 @@ def __init__(self, self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) - if self.use_fleet_sq: + if self.use_fleet_sq and not use_raw_recompute: self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, self.runfunc4) + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] + + # default segments = 2 + if use_raw_recompute: + self.layers = [ + paddle.nn.Sequential(self.runfunc0, self.runfunc1), + paddle.nn.Sequential(self.runfunc2, self.runfunc3, + self.runfunc4) + ] + def forward(self, inputs): - if self.use_fleet_sq: + if self.use_fleet_sq and not self.use_raw_recompute: return fleet.recompute_sequential({"segments": self.segments}, self.runfuncs, inputs) - if 0 in self.recompute_blocks: - recompute_func = fleet.recompute if self.use_fleet else recompute - inputs = recompute_func(self.runfunc0, inputs) - else: - inputs = self.runfunc0(inputs) - - if 1 in self.recompute_blocks: - recompute_func = fleet.recompute if self.use_fleet else recompute - inputs = recompute_func(self.runfunc1, inputs) - else: - inputs = self.runfunc1(inputs) + if self.use_raw_recompute: + inputs = fleet.recompute(self.layers[0], inputs) + return self.layers[1](inputs) - if 2 in self.recompute_blocks: - recompute_func = fleet.recompute if self.use_fleet else recompute - inputs = recompute_func(self.runfunc2, inputs, - **self.recompute_kwargs) - else: - inputs = self.runfunc2(inputs) - - if 3 in self.recompute_blocks: - recompute_func = fleet.recompute if self.use_fleet else recompute - inputs = recompute_func(self.runfunc3, inputs) - else: - inputs = self.runfunc3(inputs) - - if 4 in self.recompute_blocks: - recompute_func = fleet.recompute if self.use_fleet else recompute - inputs = recompute_func(self.runfunc4, inputs) - else: - inputs = 
self.runfunc4(inputs) + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = fleet.recompute(self.layers[i], inputs, + **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) return inputs @@ -124,6 +118,7 @@ def run_model(recompute_block=[], recompute_kwargs={}, use_fleet=False, use_fleet_sq=False, + use_raw_recompute=False, segments=1, enable_autocast=False, pure_fp16=False): @@ -137,6 +132,7 @@ def run_model(recompute_block=[], recompute_blocks=recompute_block, use_fleet=use_fleet, use_fleet_sq=use_fleet_sq, + use_raw_recompute=use_raw_recompute, segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') @@ -226,6 +222,21 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # with base recompute, and segments=2 + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + use_raw_recompute=True, + pure_fp16=pure_fp16) + + # recompute using fleet.recompute_sequential, segments=2 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() diff --git a/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py b/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py new file mode 100755 index 0000000000000..03aa397a5a4c3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py @@ -0,0 +1,220 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +from paddle.autograd import PyLayer +from paddle.distributed.fleet.utils import recompute +import random +from paddle.distributed import fleet + +import paddle.fluid.layers as layers + + +def get_fc_block(block_idx, input_size, is_last=False): + block_name = "block_" + str(block_idx) + block = paddle.nn.Sequential( + (block_name + "_fc_0", + paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_dropout", paddle.nn.Dropout(p=0.5)), + (block_name + "_relu_1", paddle.nn.ReLU()), + (block_name + "_fc_1", + paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_relu_2", paddle.nn.ReLU()), + ) + if is_last: + block.add_sublayer(block_name + "_fc_2", + paddle.nn.Linear(input_size, 1, + bias_attr=False)) # add sublayer + else: + block.add_sublayer(block_name + "_fc_2", + paddle.nn.Linear(input_size, + input_size, + bias_attr=False)) # add sublayer + return block + + +class Naive_fc_net(paddle.nn.Layer): + + def __init__(self, + input_size=10, + recompute_blocks=[1, 3], + offload=False, + partition=False, + recompute_kwargs={}): + super(Naive_fc_net, self).__init__() + self.recompute_blocks = recompute_blocks + self.recompute_kwargs = recompute_kwargs + self.offload = offload + self.partition = partition + + self.runfunc0 = get_fc_block(0, input_size, is_last=False) + self.runfunc1 = get_fc_block(1, input_size, is_last=False) + self.runfunc2 = get_fc_block(2, input_size, is_last=False) + self.runfunc3 = get_fc_block(3, input_size, is_last=False) + self.runfunc4 = get_fc_block(4, input_size, is_last=True) + + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] + + def forward(self, inputs): + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = fleet.recompute_hybrid( + { + "mp_group": fleet.fleet._hcg.get_model_parallel_group(), + "offload": self.offload, + "partition": self.partition + }, self.layers[i], inputs, **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) + + return inputs + + +def run_model(recompute_block=[], + recompute_kwargs={}, + offload=False, + partition=False, + enable_autocast=False, + pure_fp16=False): + gen = paddle.seed(10) + gen.manual_seed(10) + np.random.seed(10) + random.seed(10) + + batch_size, input_size = 1, 10 + model = Naive_fc_net(input_size, + recompute_blocks=recompute_block, + offload=offload, + partition=partition, + recompute_kwargs=recompute_kwargs) + loss_fn = paddle.nn.MSELoss(reduction='mean') + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=model.parameters()) + + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + + if enable_autocast: + scaler = paddle.amp.GradScaler() + scaler = fleet.distributed_scaler(scaler) + + loss_ = [] + param_ = [] + grad_ = [] + for step in range(10): + + x_data = np.random.randn(batch_size, input_size).astype(np.float32) + x = paddle.to_tensor(x_data) + # x.stop_gradient = False + level = 'O2' if pure_fp16 else 'O1' + with paddle.amp.auto_cast(True, level=level): + y_pred = model(x) + loss = y_pred.mean() + if enable_autocast: + scaler.scale(loss).backward() + scaler.minimize(optimizer, loss) + else: + loss_.append(np.asarray(loss).tolist()) + loss.backward() + optimizer.step() + + param_.append(np.asarray(model.parameters()[9]).tolist()) + grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist()) + + optimizer.clear_grad() + 
return loss_, param_, grad_ + + +class TestPyLayer(unittest.TestCase): + + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 1 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + fleet.init(is_collective=True, strategy=strategy) + + def test_base_case(self, enable_autocast=False, pure_fp16=False): + + def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): + self.assertEqual(loss_ref, loss) + self.assertEqual(param_ref, param) + self.assertEqual(grad_ref, grad) + + # without recompute + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + + # with recompute, offload=False, partition=False + loss, param, grad = run_model(recompute_block=[1, 3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=True, partition=False + loss, param, grad = run_model(recompute_block=[1, 2, 3], + offload=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=False, partition=True + loss, param, grad = run_model(recompute_block=[1], + partition=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=True, partition=True + loss, param, grad = run_model(recompute_block=[1, 3, 4], + offload=True, + partition=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + def test_fc_net_with_dropout(self): + self.test_base_case() + + def test_fc_net_with_amp(self): + self.test_base_case(enable_autocast=True) + + def test_fc_net_with_fp16(self): + self.test_base_case(enable_autocast=True, pure_fp16=True) + + def test_recompute_kwargs(self): + paddle.set_device("gpu") + kwargs = {"is_test": False} + with self.assertRaises(TypeError): + loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], + recompute_kwargs=kwargs) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py index 8773e8d47ed3c..10243a0faa944 100644 --- a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py @@ -26,5 +26,11 @@ def test_pipeline_parallel(self): self.run_mnist_2gpu('hybrid_parallel_pp_alexnet.py') +class TestModelParallelWithRecompute(TestMultipleGpus): + + def test_model_parallel_with_recompute(self): + self.run_mnist_2gpu("dygraph_recompute_hybrid.py") + + if __name__ == "__main__": unittest.main() From 307d8e74883947b35b2377489a63108d7215ccf0 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 30 Aug 2022 11:21:40 +0000 Subject: [PATCH 22/31] update. 
--- .../fleet/meta_parallel/parallel_layers/pp_layers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index af1e98be5c81d..1bbd3ce10e29b 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -558,8 +558,7 @@ def forward(self, input, chunk_id=None): if self._need_recompute(funcs, input): input = fleet.recompute_hybrid( { - "mp_group": - fleet.fleet._hcg.get_model_parallel_group(), + "mp_group": self._topo.get_model_parallel_group(), "offload": self._recompute_offload, "partition": self._recompute_partition }, self.forward_function(start_idx, end_idx), *input) From 52186110925aa325626f3d1768747ef5c89ddb37 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 30 Aug 2022 11:53:06 +0000 Subject: [PATCH 23/31] update. --- python/paddle/distributed/fleet/__init__.py | 8 ++++---- .../collective/multinode/dygraph_hybrid_recompute.py | 8 ++------ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 625913c647c0b..ddf280293f46a 100755 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -32,7 +32,7 @@ from .model import distributed_model from .optimizer import distributed_optimizer from .scaler import distributed_scaler -import paddle.distributed.fleet.recompute as Re +import paddle.distributed.fleet.recompute as rc __all__ = [ #noqa "CommunicateTopology", "UtilBase", "HybridCommunicateGroup", @@ -91,6 +91,6 @@ shrink = fleet.shrink get_hybrid_communicate_group = fleet.get_hybrid_communicate_group distributed_scaler = distributed_scaler -recompute = Re.recompute -recompute_sequential = Re.recompute_sequential -recompute_hybrid = Re.recompute_hybrid +recompute = rc.recompute +recompute_sequential = rc.recompute_sequential +recompute_hybrid = rc.recompute_hybrid diff --git a/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py b/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py index 16e482139ff82..28c9876ce10a5 100644 --- a/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py +++ b/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py @@ -70,10 +70,9 @@ def forward(self, pred, label): class RecomputeMatmulBlock(nn.Layer): - def __init__(self, mp, seed, m, n, k, use_fleet=False): + def __init__(self, mp, seed, m, n, k): super(RecomputeMatmulBlock, self).__init__() self.mp = mp - self.use_fleet = use_fleet if mp is not None and mp.nranks > 1: mp_linear_1 = fleet.meta_parallel.ColumnParallelLinear( m, @@ -100,10 +99,7 @@ def __init__(self, mp, seed, m, n, k, use_fleet=False): def forward(self, x): if self.mp: - if self.use_fleet: - return fleet.recompute(self.layers, x) - else: - return recompute(self.layers, x) + return fleet.recompute(self.layers, x) else: return self.layers(x) From ccd41446cebe2093a475de848c83bf4e31bbdccb Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 30 Aug 2022 12:04:00 +0000 Subject: [PATCH 24/31] update. 
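With `recompute`, `recompute_sequential` and `recompute_hybrid` now re-exported on the `fleet` namespace, the plain entry point can be called directly on a sub-network. A minimal single-card sketch of that call (shapes are illustrative, a GPU build is assumed, and the snippet is not part of the patch):

    import paddle
    from paddle.distributed import fleet

    block = paddle.nn.Sequential(
        paddle.nn.Linear(10, 10, bias_attr=False),
        paddle.nn.ReLU(),
    )
    x = paddle.randn([4, 10])
    x.stop_gradient = False

    # activations inside `block` are dropped in the forward pass and
    # recomputed during backward, trading extra compute for memory
    y = fleet.recompute(block, x)
    y.mean().backward()
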
--- .../collective/fleet/test_dygraph_recompute.py | 12 ------------ .../fleet/test_dygraph_recompute_for_eager.py | 11 ----------- 2 files changed, 23 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py index f9f2909ea24be..a7180aecc1f0e 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py @@ -53,7 +53,6 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], - use_fleet=False, use_fleet_sq=False, segments=1, use_raw_recompute=False, @@ -61,7 +60,6 @@ def __init__(self, super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs - self.use_fleet = use_fleet self.use_fleet_sq = use_fleet_sq self.use_raw_recompute = use_raw_recompute self.segments = segments @@ -112,7 +110,6 @@ def forward(self, inputs): def run_model(recompute_block=[], recompute_kwargs={}, - use_fleet=False, use_fleet_sq=False, use_raw_recompute=False, segments=1, @@ -126,7 +123,6 @@ def run_model(recompute_block=[], batch_size, input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, - use_fleet=use_fleet, use_fleet_sq=use_fleet_sq, use_raw_recompute=use_raw_recompute, segments=segments, @@ -204,16 +200,8 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - # recompute second block using fleet - loss, param, grad = run_model(recompute_block=[1], - use_fleet=True, - enable_autocast=enable_autocast, - pure_fp16=pure_fp16) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - # recompute second & fourth block using fleet loss, param, grad = run_model(recompute_block=[1, 3], - use_fleet=True, enable_autocast=enable_autocast, pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py index 2ff6b4db58b20..5cddfa2eff8bf 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py @@ -57,7 +57,6 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], - use_fleet=False, use_fleet_sq=False, segments=1, use_raw_recompute=False, @@ -65,7 +64,6 @@ def __init__(self, super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs - self.use_fleet = use_fleet self.use_fleet_sq = use_fleet_sq self.use_raw_recompute = use_raw_recompute self.segments = segments @@ -116,7 +114,6 @@ def forward(self, inputs): def run_model(recompute_block=[], recompute_kwargs={}, - use_fleet=False, use_fleet_sq=False, use_raw_recompute=False, segments=1, @@ -130,7 +127,6 @@ def run_model(recompute_block=[], batch_size, input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, - use_fleet=use_fleet, use_fleet_sq=use_fleet_sq, use_raw_recompute=use_raw_recompute, segments=segments, @@ -208,13 +204,6 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): 
pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - # recompute second block using fleet - loss, param, grad = run_model(recompute_block=[1], - use_fleet=True, - enable_autocast=enable_autocast, - pure_fp16=pure_fp16) - check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) - # recompute_sequential with segments=1 using fleet loss, param, grad = run_model(recompute_block=[], use_fleet_sq=True, From ec2b973c03c3fdb85c1d3a7e5ad617b1fbf789c8 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 30 Aug 2022 12:07:10 +0000 Subject: [PATCH 25/31] update. --- .../unittests/collective/multinode/dygraph_hybrid_recompute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py b/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py index 28c9876ce10a5..19eaeb462de64 100644 --- a/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py +++ b/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py @@ -139,7 +139,7 @@ def __init__(self, hcg): self.layers_pp.append(dp_linear) mp = hcg.get_model_parallel_group() if hcg else None for i in range(6): - mp_layer = RecomputeBlock(mp, 1024 + i, 64, 128, 64, True) + mp_layer = RecomputeBlock(mp, 1024 + i, 64, 128, 64) act = nn.ReLU6() layer_seq = nn.Sequential(mp_layer, act) self.layers_pp.append(layer_seq) From f261c51fa945fdc7f18fd6217ee4c7d1387f2a51 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Wed, 31 Aug 2022 01:40:11 +0000 Subject: [PATCH 26/31] update. --- .../fleet/meta_parallel/parallel_layers/pp_layers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 1bbd3ce10e29b..af1e98be5c81d 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -558,7 +558,8 @@ def forward(self, input, chunk_id=None): if self._need_recompute(funcs, input): input = fleet.recompute_hybrid( { - "mp_group": self._topo.get_model_parallel_group(), + "mp_group": + fleet.fleet._hcg.get_model_parallel_group(), "offload": self._recompute_offload, "partition": self._recompute_partition }, self.forward_function(start_idx, end_idx), *input) From a29cb3fbf3a7a08901c74ebe6ccca10447fa179b Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Wed, 31 Aug 2022 12:09:07 +0000 Subject: [PATCH 27/31] update note. 
--- python/paddle/distributed/fleet/recompute/recompute_hybrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py index 8ddb20842f862..14f1612fa05a4 100644 --- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -37,7 +37,7 @@ def _split_activation(tensor, mp_group): tensor_numel = paddle.numel(tensor) assert tensor_numel != 0, "can't recompute zero element" - assert tensor_numel % mp_degree == 0, "The capacity of the activation () cannot be divisible by mp_degree()".format( + assert tensor_numel % mp_degree == 0, "The capacity of the activation ({}) cannot be divisible by mp_degree({})".format( tensor_numel, mp_degree) # use inplace operation to save memory From ab3876778c3015a67942d07f67e7ea8e7897ba5d Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Fri, 2 Sep 2022 05:26:38 +0000 Subject: [PATCH 28/31] update recompute_ctx. --- .../fleet/meta_parallel/parallel_layers/pp_layers.py | 10 +++------- .../incubate/distributed/models/moe/moe_layer.py | 10 ++-------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index 61467364777f5..ad5013c0c4817 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -309,13 +309,13 @@ def __init__(self, self._loss_fn = loss_fn self._topo = topology self._recompute_interval = recompute_interval + self.recompute_ctx = recompute_ctx if recompute_interval > 0: assert recompute_ctx is not None, "recompute_ctx must be not None for recompute." offload = recompute_ctx.get('offload', False) partition = recompute_ctx.get('partition', False) - _initialize_recompute_setting(offload, partition) logger.info( "Start Recompute for PipeLineParallel. 
recompute_offload: {}, recompute_partition: {}" .format(offload, partition)) @@ -636,12 +636,8 @@ def forward(self, input, chunk_id=None): if self._need_recompute(funcs, input): input = fleet.recompute_hybrid( - { - "mp_group": - fleet.fleet._hcg.get_model_parallel_group(), - "offload": self._recompute_offload, - "partition": self._recompute_partition - }, self.forward_function(start_idx, end_idx), *input) + self.recompute_ctx, + self.forward_function(start_idx, end_idx), *input) else: input = self.forward_function(start_idx, end_idx)(*input) diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index d0ae5eb07ff8b..f3c139dcb36bc 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -330,8 +330,6 @@ def __init__(self, self.world_size = self.group.nranks self.num_expert = len(experts) self.recompute_interval = recompute_interval - self.recompute_offload = recompute_offload - self.recompute_partition = recompute_partition assert experts is not None self.experts = experts @@ -426,12 +424,8 @@ def experts_fwd(x, fwd_expert_count, experts): if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: - x = fleet.recompute_hybrid( - { - "mp_group": fleet.fleet._hcg.get_model_parallel_group(), - "offload": self.recompute_offload, - "partition": self.recompute_partition - }, experts_fwd, x, fwd_expert_count.numpy(), self.experts) + x = fleet.recompute_hybrid(self.recompute_ctx, experts_fwd, x, + fwd_expert_count.numpy(), self.experts) out_batch_size = inp.shape[0] if len(gate.shape) == 2: From a99c046d9df933fee613447481d9627133f4d16c Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 13 Sep 2022 08:01:22 +0000 Subject: [PATCH 29/31] update __all__. 
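Both call sites above now read a single `recompute_ctx` dict instead of rebuilding it per call. A rough sketch of that pattern with a made-up layer name, `InnerWithRecompute`; only the `recompute_interval` / `recompute_ctx` handling mirrors the patched code, the rest is illustrative:

    import paddle
    from paddle.distributed import fleet

    class InnerWithRecompute(paddle.nn.Layer):

        def __init__(self, inner, recompute_interval=0, recompute_ctx=None):
            super(InnerWithRecompute, self).__init__()
            self.inner = inner
            self.recompute_interval = recompute_interval
            if recompute_interval > 0:
                assert recompute_ctx is not None, "recompute_ctx must be not None for recompute."
            # stored once at construction time; no separate offload/partition members needed
            self.recompute_ctx = recompute_ctx

        def forward(self, x):
            if self.recompute_interval <= 0:
                return self.inner(x)
            # the stored ctx is passed straight through to recompute_hybrid
            return fleet.recompute_hybrid(self.recompute_ctx, self.inner, x)

The ctx itself would typically be built once where the model is assembled, e.g. `{"mp_group": fleet.get_hybrid_communicate_group().get_model_parallel_group(), "offload": False, "partition": False}`.
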
--- .../distributed/fleet/meta_parallel/pp_utils/utils.py | 5 +---- python/paddle/distributed/fleet/recompute/__init__.py | 2 +- python/paddle/distributed/fleet/recompute/recompute.py | 5 +---- .../distributed/fleet/recompute/recompute_hybrid.py | 10 ++++------ 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index 82c08e37e7e52..c2008abb71c53 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -16,10 +16,7 @@ from paddle.fluid import core from paddle import _C_ops, _legacy_C_ops -__all__ = [ - "is_float_tensor", "get_tensor_dtype", "paddle_2_number", "number_2_dtype", - "get_tensor_bytes", "_all_gather" -] +__all__ = [] FLOAT_TYPE_DICT = { paddle.float16: "float16", diff --git a/python/paddle/distributed/fleet/recompute/__init__.py b/python/paddle/distributed/fleet/recompute/__init__.py index e5e732560a55d..7e5bcdb1db277 100644 --- a/python/paddle/distributed/fleet/recompute/__init__.py +++ b/python/paddle/distributed/fleet/recompute/__init__.py @@ -15,4 +15,4 @@ from .recompute import recompute, recompute_sequential from .recompute_hybrid import recompute_hybrid -__all__ = ["recompute", "recompute_sequential", "recompute_hybrid"] +__all__ = [] diff --git a/python/paddle/distributed/fleet/recompute/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py index 6474acdbf9c5a..73dddf740dc27 100755 --- a/python/paddle/distributed/fleet/recompute/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -30,10 +30,7 @@ ch.setFormatter(formatter) logger.addHandler(ch) -__all__ = [ - "recompute", "recompute_sequential", "swith_rng_state_tracker", - "check_recompute_necessary", "detach_variable" -] +__all__ = [] def detach_variable(inputs): diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py index 14f1612fa05a4..83a1958a9eeac 100644 --- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -25,7 +25,7 @@ from .recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker from ..meta_parallel.pp_utils import utils -__all__ = ["recompute_hybrid"] +__all__ = [] def _split_activation(tensor, mp_group): @@ -216,11 +216,10 @@ def recompute_hybrid(ctx, function, *args, **kwargs): # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor Parameters: - ctx(dict): include 'mp_group', 'offload', 'partition' and 'preserve_rng_state' keys. the key 'mp_group' (Group), represents the avtivations are splitted + ctx(dict): include 'mp_group', 'offload', and 'partition' keys. the key 'mp_group' (Group), represents the avtivations are splitted in which group. the key 'offload' (bool, optional, default=False), represents whether to offload to cpu. the key 'partition' (bool, optional, default=False), - represents whether to split activations in the mp_group. the key 'preserve_rng_state' (bool, optional, default=True) indicate whether to save the forward rng. - If it is True, then the last forward rng value will be restored when the forward recalculation of backpropagation is performed. and some keys such as 'segments', - are invalid here, they are useful in 'recompute_sequential' API. 
+ represents whether to split activations in the mp_group. and some keys such as 'segments' and 'preserve_rng_state' are invalid here, they are useful in + 'recompute_sequential' API. function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model whose intermediate activations will be released to save memory in forward stage and will be recomputed in backward stage for gradient calculation. @@ -237,7 +236,6 @@ def recompute_hybrid(ctx, function, *args, **kwargs): offload = ctx.get('offload', False) partition = ctx.get('partition', False) - preserve_rng_state = ctx.get('preserve_rng_state', True) all_outputs = [] _HPRecomputeFunction.apply(function, all_outputs, mp_group, offload, From 4af8ad5e0a8c80de198cf46292dffcb8e4a4d774 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Tue, 13 Sep 2022 12:51:04 +0000 Subject: [PATCH 30/31] update rng. --- .../paddle/distributed/fleet/recompute/recompute_hybrid.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py index 83a1958a9eeac..02faa8cd6a163 100644 --- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -75,7 +75,7 @@ class _HPRecomputeFunction(PyLayer): @staticmethod def forward(ctx, run_function, all_outputs, mp_group, offload, partition, - preserve_rng_state, *args, **kwargs): + *args, **kwargs): check_recompute_necessary(args) # store for recomputing @@ -84,7 +84,6 @@ def forward(ctx, run_function, all_outputs, mp_group, offload, partition, ctx.kwargs = kwargs # store the rng states - assert preserve_rng_state, 'preserve_rng_state must be True in recompute_hybrid.' ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( ).get_states_tracker() @@ -239,7 +238,7 @@ def recompute_hybrid(ctx, function, *args, **kwargs): all_outputs = [] _HPRecomputeFunction.apply(function, all_outputs, mp_group, offload, - partition, preserve_rng_state, *args, **kwargs) + partition, *args, **kwargs) if len(all_outputs) == 1: return all_outputs[0] From 53b1fcce3f09cee4da8111f4d2f2c0c8864a7db8 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding Date: Thu, 15 Sep 2022 03:16:23 +0000 Subject: [PATCH 31/31] update. --- .../paddle/distributed/fleet/recompute/recompute_hybrid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py index 02faa8cd6a163..4883cad2511bb 100644 --- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -15,8 +15,8 @@ import contextlib import paddle -from paddle.fluid import core from paddle import _C_ops, _legacy_C_ops +from paddle.fluid import core from paddle.autograd import PyLayer from paddle.fluid import framework from ..meta_parallel.parallel_layers.random import get_rng_state_tracker @@ -217,14 +217,14 @@ def recompute_hybrid(ctx, function, *args, **kwargs): Parameters: ctx(dict): include 'mp_group', 'offload', and 'partition' keys. the key 'mp_group' (Group), represents the avtivations are splitted in which group. the key 'offload' (bool, optional, default=False), represents whether to offload to cpu. 
the key 'partition' (bool, optional, default=False), - represents whether to split activations in the mp_group. and some keys such as 'segments' and 'preserve_rng_state' are invalid here, they are useful in + represents whether to split activations in the mp_group. and some keys such as 'segments' and 'preserve_rng_state' are invalid here, they are useful in 'recompute_sequential' API. function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model whose intermediate activations will be released to save memory in forward stage and will be recomputed in backward stage for gradient calculation. *args(Tensor): inputs(tuple) to the function. - **kwargs(Dict): inputs(dict) to the function. + **kwargs(Dict): inputs(dict) to the function. Returns: Output of function on args and kwargs.
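
For reference, a usage sketch consistent with the parameters documented above and with the `dygraph_recompute_hybrid.py` test earlier in this series. It assumes two GPUs launched the same way the unit test is (mp_degree=2); the layer shapes and strategy values are illustrative, not taken from the patch:

    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
    fleet.init(is_collective=True, strategy=strategy)

    block = paddle.nn.Sequential(
        paddle.nn.Linear(10, 10, bias_attr=False),
        paddle.nn.ReLU(),
    )
    x = paddle.randn([1, 10])
    x.stop_gradient = False

    ctx = {
        "mp_group": fleet.get_hybrid_communicate_group().get_model_parallel_group(),
        "offload": False,    # keep stashed activations on GPU
        "partition": False,  # do not split activations across the mp group
    }
    y = fleet.recompute_hybrid(ctx, block, x)
    y.mean().backward()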