diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 11d7643c676dd..ddf280293f46a 100755 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -32,6 +32,7 @@ from .model import distributed_model from .optimizer import distributed_optimizer from .scaler import distributed_scaler +import paddle.distributed.fleet.recompute as rc __all__ = [ #noqa "CommunicateTopology", "UtilBase", "HybridCommunicateGroup", @@ -90,3 +91,6 @@ shrink = fleet.shrink get_hybrid_communicate_group = fleet.get_hybrid_communicate_group distributed_scaler = distributed_scaler +recompute = rc.recompute +recompute_sequential = rc.recompute_sequential +recompute_hybrid = rc.recompute_hybrid diff --git a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py index f6878ec1d8627..3b00961e78fbb 100755 --- a/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py +++ b/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py @@ -49,7 +49,7 @@ import paddle from paddle.fluid.dygraph.layers import Layer from ...utils.log_util import logger, layer_to_str -from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting +from paddle.distributed import fleet from paddle.fluid.framework import in_dygraph_mode __all__ = [] @@ -309,13 +309,13 @@ def __init__(self, self._loss_fn = loss_fn self._topo = topology self._recompute_interval = recompute_interval + self.recompute_ctx = recompute_ctx if recompute_interval > 0: assert recompute_ctx is not None, "recompute_ctx must be not None for recompute." offload = recompute_ctx.get('offload', False) partition = recompute_ctx.get('partition', False) - _initialize_recompute_setting(offload, partition) logger.info( "Start Recompute for PipeLineParallel. 
recompute_offload: {}, recompute_partition: {}" .format(offload, partition)) @@ -638,7 +638,8 @@ def forward(self, input, chunk_id=None): input = (input, ) if self._need_recompute(funcs, input): - input = _hp_recompute( + input = fleet.recompute_hybrid( + self.recompute_ctx, self.forward_function(start_idx, end_idx), *input) else: input = self.forward_function(start_idx, end_idx)(*input) diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py index 876f9ffaed32b..537885bfad349 100755 --- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py +++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py @@ -14,7 +14,6 @@ import paddle import paddle.fluid as fluid from .meta_parallel_base import MetaParallelBase -from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg from .parallel_layers.pp_layers import PipelineLayer from ..utils.hybrid_parallel_util import broadcast_mp_parameters @@ -61,8 +60,6 @@ def __init__(self, layers, hcg, strategy): p2p.initialize_p2p_groups(hcg, self._using_cache) - _initialize_recompute_hcg(hcg) - self.global_rank = self._hcg.get_global_rank() self.micro_batch_id = 0 diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py index 786eb20487a52..04575bfb23194 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/__init__.py @@ -12,6 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .utils import get_tensor_bytes - __all__ = [] diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py index ce5c1cfe9eb85..3b4094f047552 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/p2p_communication.py @@ -13,12 +13,12 @@ # limitations under the License. import paddle -from .utils import paddle_2_number, number_2_dtype from ...utils.log_util import logger import numpy as np from paddle import _C_ops, _legacy_C_ops import paddle.fluid.core as core from paddle.fluid.framework import _in_legacy_dygraph, _non_static_mode, in_dygraph_mode +from .utils import paddle_2_number, paddle_2_number, number_2_dtype _hcg = None _use_cache = False diff --git a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py index 8ec7f0f037b06..683cc51d27907 100644 --- a/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py +++ b/python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py @@ -12,16 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import contextlib - import paddle from paddle.fluid import core from paddle import _C_ops, _legacy_C_ops -from paddle.autograd import PyLayer -from paddle.fluid import framework -from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker -from ..parallel_layers.random import get_rng_state_tracker -from paddle.fluid.framework import in_dygraph_mode __all__ = [] @@ -88,23 +81,6 @@ def get_tensor_bytes(tensor): return tensor.numel() * elem_size -_hcg = None -_recompute_offload = False -_recompute_partition = False - - -def _initialize_recompute_setting(is_offload, is_partition): - global _recompute_offload, _recompute_partition - - _recompute_offload = is_offload - _recompute_partition = is_partition - - -def _initialize_recompute_hcg(hcg): - global _hcg - _hcg = hcg - - def _all_gather(tensor, group=None, use_calc_stream=True): """ The main difference with paddle.distributed.all_gather: @@ -117,187 +93,3 @@ def _all_gather(tensor, group=None, use_calc_stream=True): ).nranks if group is None else group.nranks return _legacy_C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id, 'nranks', nranks) - - -def _split_activation(tensor): - global _hcg - - mp_degree = _hcg.get_model_parallel_world_size() - mp_rank = _hcg.get_model_parallel_rank() - if mp_degree < 2: - return tensor - - tensor_numel = paddle.numel(tensor) - assert tensor_numel != 0, "can't recompute zero element" - assert tensor_numel % mp_degree == 0, "The capacity of the activation () cannot be divisible by mp_degree()".format( - tensor_numel, mp_degree) - - # use inplace operation to save memory - data = tensor.flatten_() - - part_size = tensor_numel // mp_degree - start = part_size * mp_rank - end = start + part_size - return data[start:end] - - -def _merge_activation(tensor): - global _hcg - mp_degree = _hcg.get_model_parallel_world_size() - mp_rank = _hcg.get_model_parallel_rank() - mp_group = _hcg.get_model_parallel_group() - if mp_degree < 2: - return tensor - return _all_gather(tensor, group=mp_group) - - -class _HPRecomputeFunction(PyLayer): - """ - Compared with paddle.distributed.fleet.utils.recompute, there are the following differences: - 1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type. - 2. Offload support for activation - 3. Support MP segmentation of activation to further reduce cuda memory - 4. 
Adapt to the random state of MP - """ - - @staticmethod - def forward(ctx, run_function, all_outputs, *args): - check_recompute_necessary(args) - - # store for recomputing - ctx.run_function = run_function - - # store the rng states - ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( - ).get_states_tracker() - - # save input for backward - ctx.inputs = [] - ctx.tensor_indices = [] - ctx.tensor_shapes = [] - tensor_inputs = [] - - cur_device = paddle.get_device() - assert 'gpu:' in paddle.get_device( - ), "Recompute with RNG is not support current device: {}.".format( - cur_device) - - # TODO support AMP - tracer = framework._dygraph_tracer() - ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True - if tracer._amp_level == core.AmpLevel.O2: - ctx.amp_level = 'O2' - elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): - ctx.amp_level = 'O1' - else: - raise ValueError("unsupported amp level: {}".format( - tracer._amp_level)) - ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() - - with paddle.no_grad(): - outputs = run_function(*args) - - for i, arg in enumerate(args): - if paddle.is_tensor(arg): - state = arg.stop_gradient - if _recompute_partition: - ctx.tensor_shapes.append(arg.shape) - partition = _split_activation(arg.detach()).clone() - # TODO(shenliang03) not use calculate stream to D2H to speed - arg = partition.cpu() if _recompute_offload else partition - else: - arg = arg.cpu() if _recompute_offload else arg - arg.stop_gradient = state - tensor_inputs.append(arg) - ctx.tensor_indices.append(i) - ctx.inputs.append(None) - else: - ctx.inputs.append(arg) - - ctx.save_for_backward(*tensor_inputs) - - if paddle.is_tensor(outputs): - all_outputs += [outputs] - return outputs - else: - all_outputs += outputs - return tuple(outputs) - - @staticmethod - def backward(ctx, *args): - with paddle.fluid.dygraph.guard(): - # Restore inputs - inputs = list(ctx.inputs) - tensor_indices = ctx.tensor_indices - tensor_shapes = ctx.tensor_shapes - tensors = list(ctx.saved_tensor()) - - device_id = paddle.distributed.ParallelEnv().device_id - for i, idx in enumerate(tensor_indices): - if _recompute_partition: - state = tensors[i].stop_gradient - tensors[i] = _merge_activation( - tensors[i]).detach().reshape_(tensor_shapes[i]) - tensors[i].stop_gradient = state - inputs[idx] = tensors[i].cuda( - device_id) if _recompute_offload else tensors[i] - - tracer = framework._dygraph_tracer() - tracer._has_grad = True - - # need restore auto_cast state as well as w/b list - with swith_rng_state_tracker(ctx.fwd_cuda_rng_state, - ctx.fwd_cuda_rng_state_tracker): - with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, - custom_white_list=ctx.amp_white_list, - custom_black_list=ctx.amp_black_list, - level=ctx.amp_level): - detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) - - if isinstance(outputs, (core.VarBase, core.eager.Tensor)): - outputs = (outputs, ) - assert len(outputs) == len(args) - - forward_outputs_with_grad = [] - backward_inputs = [] - - for i in range(len(outputs)): - if isinstance( - outputs[i], - (core.VarBase, - core.eager.Tensor)) and not outputs[i].stop_gradient: - forward_outputs_with_grad.append(outputs[i]) - backward_inputs.append(args[i]) - - if len(forward_outputs_with_grad) == 0: - raise RuntimeError( - "none of output has stop_gradient=False, this recompute() is not necessary" - ) - - # actually backward - 
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs) - grads = tuple(inp._grad_ivar() for inp in detached_inputs - if isinstance(inp, (core.VarBase, core.eager.Tensor))) - return grads - - -def _hp_recompute(function, *args): - # NODTE(shenliang03)The current hybrid parallel recompute has limitations. - # It cannot handle the following situations: - # 1. The calculation output of recompute, there are tensors that do not require gradients. - # 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach(). - # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor - - all_outputs = [] - _HPRecomputeFunction.apply(function, all_outputs, *args) - - if len(all_outputs) == 1: - return all_outputs[0] - else: - for output in all_outputs: - if paddle.is_tensor(output) and not is_float_tensor(output): - output.stop_gradient = True - - return tuple(all_outputs) diff --git a/python/paddle/distributed/fleet/model.py b/python/paddle/distributed/fleet/model.py index fea2614fe84c3..40633788f12d4 100644 --- a/python/paddle/distributed/fleet/model.py +++ b/python/paddle/distributed/fleet/model.py @@ -20,46 +20,9 @@ from .meta_parallel import TensorParallel, model_parallel_random_seed from .meta_parallel import PipelineParallel, ShardingParallel, PipelineParallelWithInterleave, PipelineLayer from paddle.fluid import core -from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar from paddle.distributed import fleet - -class _RecomputeModelWrapper(paddle.nn.Layer): - - def __init__(self, model, segments=2, preserve_rng_state=True): - super(_RecomputeModelWrapper, self).__init__() - assert isinstance(model, paddle.nn.Sequential), ( - "The model passed to RecomputeModelWrapper must be of type " - "paddle.nn.Sequential.") - self._model = model - self._segments = segments - self._preserve_rng_state = preserve_rng_state - self._layers = list(model.children()) - self._segment_size = len(self._layers) // segments - - def _run_func(self, begin, end): - - def do_run(input): - for i in range(begin, end): - input = self._layers[i](input) - return input - - return do_run - - def _checkpoint(self, func, *args, **kwargs): - return LegacyRecomputeFunction.apply(func, self._preserve_rng_state, - *args) - - def forward(self, input): - end = 0 - for begin in range(0, self._segment_size * (self._segments - 1), - self._segment_size): - end = begin + self._segment_size - input = self._checkpoint(self._run_func(begin, end), input) - return self._run_func(end, len(self._layers))(input) - - _grad_scalar = None @@ -125,7 +88,6 @@ def forward(self, x): return model amp_enable = False - recompute_enable = False strategy = fleet_env._user_defined_strategy if strategy.amp == True: amp_enable = True @@ -154,10 +116,6 @@ def forward(self, x): decr_every_n_nan_or_inf=decr_every_n_nan_or_inf, use_dynamic_loss_scaling=use_dynamic_loss_scaling) - if strategy.recompute == True: - recompute_enable = True - model = _RecomputeModelWrapper(model) - if strategy.heter_ccl_mode == True: distributed_model = paddle.DataParallel( model, diff --git a/python/paddle/distributed/fleet/recompute/__init__.py b/python/paddle/distributed/fleet/recompute/__init__.py new file mode 100644 index 0000000000000..7e5bcdb1db277 --- /dev/null +++ b/python/paddle/distributed/fleet/recompute/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .recompute import recompute, recompute_sequential +from .recompute_hybrid import recompute_hybrid + +__all__ = [] diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/recompute/recompute.py similarity index 88% rename from python/paddle/distributed/fleet/utils/recompute.py rename to python/paddle/distributed/fleet/recompute/recompute.py index 2dddb1d9fb492..28ded25a0e6e0 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/recompute/recompute.py @@ -207,12 +207,13 @@ def backward(ctx, *args): class RecomputeFunction(PyLayer): @staticmethod - def forward(ctx, run_function, preserve_rng_state, *args): + def forward(ctx, run_function, preserve_rng_state, *args, **kwargs): from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker # store for recomputing ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state + ctx.kwargs = kwargs # NOTE the number of outputs of backward() should be equal to the number of tensors in forward()'s input # the order of tensors in backward()'s output should be the same as tensors in forward()'s input @@ -265,7 +266,7 @@ def forward(ctx, run_function, preserve_rng_state, *args): ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() with paddle.no_grad(): - outputs = run_function(*args) + outputs = run_function(*args, **kwargs) return outputs @staticmethod @@ -297,7 +298,8 @@ def backward(ctx, *args): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) + outputs = ctx.run_function(*detached_inputs, + **ctx.kwargs) else: with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, custom_white_list=ctx.amp_white_list, @@ -305,7 +307,7 @@ def backward(ctx, *args): level=ctx.amp_level, dtype=ctx.amp_dtype): detached_inputs = detach_variable(tuple(inputs)) - outputs = ctx.run_function(*detached_inputs) + outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) if isinstance(outputs, (core.VarBase, core.eager.Tensor)): outputs = (outputs, ) @@ -352,7 +354,7 @@ def recompute(function, *args, **kwargs): recompute intermediate activations to save then memory. Parameters: - function(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model + function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model whose intermediate activations will be released to save memory in forward stage and will be recomputed in backward stage for gradient calculation. *args(Tensor): inputs to the function. 
@@ -466,11 +468,59 @@ def run_model(cuda_state, recompute_block=[], recompute_kwargs={}): """ # Hack to mix *args with **kwargs in a python 2.7-compliant way preserve = kwargs.pop('preserve_rng_state', True) - if kwargs: - raise ValueError("Unexpected keyword arguments: " + - ",".join(arg for arg in kwargs)) if framework._dygraph_tracer()._has_grad: check_recompute_necessary(args) - return RecomputeFunction.apply(function, preserve, *args) + return RecomputeFunction.apply(function, preserve, *args, **kwargs) + + +def recompute_sequential(ctx, functions, *args, **kwargs): + """ + recompute intermediate activations to save then memory for 'Sequential' models. + + Parameters: + ctx(dict): include 'segments' and 'preserve_rng_state' keys, the key 'segments' (int, default 1), represents the number of chunks to create in the model, + the key 'preserve_rng_state' (bool, optional, default=True) indicate whether to save the forward rng. If it is True, then the last forward rng value will be + restored when the forward recalculation of backpropagation is performed. and some keys such as 'mp_group', 'offload' and 'partition' are invalid here, + they are useful in 'recompute_hybrid' API. + functions(paddle.nn.Sequential): layer of sequence of layers that describes part of forward pass of the model + whose intermediate activations will be released to save memory in forward stage and will be recomputed + in backward stage for gradient calculation. + *args(Tensor): inputs(tuple) to the function. + **kwargs(Dict): inputs(dict) to the function. + + Returns: + Output of function on args and kwargs. + + Examples: + .. code-block:: python + + model = paddle.nn.Sequential(...) + input = recompute_sequential({'segments' : 1}, model, input) + """ + segments = ctx.get('segments', 1) + preserve_rng_state = ctx.get('preserve_rng_state', True) + + def _run_func(begin, end, funcs): + + def do_run(input): + for i in range(begin, end + 1): + input = funcs[i](input) + return input + + return do_run + + if isinstance(functions, paddle.nn.Sequential): + functions = list(functions.children()) + + segment_size = len(functions) // segments + + end = -1 + for begin in range(0, segment_size * (segments - 1), segment_size): + end = begin + segment_size - 1 + args = recompute(_run_func(begin, end, functions), + *args, + preserve_rng_state=preserve_rng_state, + **kwargs) + return _run_func(end + 1, len(functions) - 1, functions)(args) diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py new file mode 100644 index 0000000000000..4883cad2511bb --- /dev/null +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -0,0 +1,250 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import contextlib + +import paddle +from paddle import _C_ops, _legacy_C_ops +from paddle.fluid import core +from paddle.autograd import PyLayer +from paddle.fluid import framework +from ..meta_parallel.parallel_layers.random import get_rng_state_tracker +from paddle.fluid.framework import in_dygraph_mode +from paddle.distributed import fleet +from .recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker +from ..meta_parallel.pp_utils import utils + +__all__ = [] + + +def _split_activation(tensor, mp_group): + + mp_degree = mp_group.nranks + mp_rank = mp_group.rank + if mp_degree < 2: + return tensor + + tensor_numel = paddle.numel(tensor) + assert tensor_numel != 0, "can't recompute zero element" + assert tensor_numel % mp_degree == 0, "The capacity of the activation ({}) cannot be divisible by mp_degree({})".format( + tensor_numel, mp_degree) + + # use inplace operation to save memory + data = tensor.flatten_() + + part_size = tensor_numel // mp_degree + start = part_size * mp_rank + end = start + part_size + return data[start:end] + + +def _merge_activation(tensor, mp_group): + mp_degree = mp_group.nranks + mp_rank = mp_group.rank + if mp_degree < 2: + return tensor + + # adapt to new dygraph + tensor_shape = list(tensor.shape) + tensor_shape[0] *= mp_group.nranks + out = paddle.empty(tensor_shape, tensor.dtype) + task = mp_group.process_group.all_gather(tensor.cuda(), out) + task.wait() + return out + + +class _HPRecomputeFunction(PyLayer): + """ + Compared with paddle.distributed.fleet.utils.recompute, there are the following differences: + 1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type. + 2. Offload support for activation + 3. Support MP segmentation of activation to further reduce cuda memory + 4. 
Adapt to the random state of MP + """ + + @staticmethod + def forward(ctx, run_function, all_outputs, mp_group, offload, partition, + *args, **kwargs): + check_recompute_necessary(args) + + # store for recomputing + ctx.run_function = run_function + + ctx.kwargs = kwargs + + # store the rng states + ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker( + ).get_states_tracker() + + # save config info + ctx.mp_group = mp_group + ctx.offload = offload + ctx.partition = partition + + # save input for backward + ctx.inputs = [] + ctx.tensor_indices = [] + ctx.tensor_shapes = [] + tensor_inputs = [] + + cur_device = paddle.get_device() + assert 'gpu:' in paddle.get_device( + ), "Recompute with RNG is not support current device: {}.".format( + cur_device) + + # TODO support AMP + tracer = framework._dygraph_tracer() + ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True + if tracer._amp_level == core.AmpLevel.O2: + ctx.amp_level = 'O2' + elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0): + ctx.amp_level = 'O1' + else: + raise ValueError("unsupported amp level: {}".format( + tracer._amp_level)) + ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list() + + with paddle.no_grad(): + outputs = run_function(*args, **kwargs) + + for i, arg in enumerate(args): + if paddle.is_tensor(arg): + state = arg.stop_gradient + if partition: + ctx.tensor_shapes.append(arg.shape) + partition = _split_activation(arg.detach(), + mp_group).clone() + # TODO(shenliang03) not use calculate stream to D2H to speed + arg = partition.cpu() if offload else partition + else: + arg = arg.cpu() if offload else arg + arg.stop_gradient = state + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + ctx.save_for_backward(*tensor_inputs) + + if paddle.is_tensor(outputs): + all_outputs += [outputs] + return outputs + else: + all_outputs += outputs + return tuple(outputs) + + @staticmethod + def backward(ctx, *args): + with paddle.fluid.dygraph.guard(): + # Restore inputs + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensor_shapes = ctx.tensor_shapes + tensors = list(ctx.saved_tensor()) + + device_id = paddle.distributed.ParallelEnv().device_id + for i, idx in enumerate(tensor_indices): + if ctx.partition: + state = tensors[i].stop_gradient + tensors[i] = _merge_activation( + tensors[i], + ctx.mp_group).detach().reshape_(tensor_shapes[i]) + tensors[i].stop_gradient = state + inputs[idx] = tensors[i].cuda( + device_id) if ctx.offload else tensors[i] + + tracer = framework._dygraph_tracer() + tracer._has_grad = True + + # need restore auto_cast state as well as w/b list + with swith_rng_state_tracker(ctx.fwd_cuda_rng_state, + ctx.fwd_cuda_rng_state_tracker): + with paddle.amp.auto_cast(enable=ctx.is_fw_autocast, + custom_white_list=ctx.amp_white_list, + custom_black_list=ctx.amp_black_list, + level=ctx.amp_level): + detached_inputs = detach_variable(tuple(inputs)) + outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) + + if isinstance(outputs, (core.VarBase, core.eager.Tensor)): + outputs = (outputs, ) + assert len(outputs) == len(args) + + forward_outputs_with_grad = [] + backward_inputs = [] + + for i in range(len(outputs)): + if isinstance( + outputs[i], + (core.VarBase, + core.eager.Tensor)) and not outputs[i].stop_gradient: + forward_outputs_with_grad.append(outputs[i]) + backward_inputs.append(args[i]) + + if 
len(forward_outputs_with_grad) == 0: + raise RuntimeError( + "none of output has stop_gradient=False, this recompute() is not necessary" + ) + + # actually backward + paddle.autograd.backward(forward_outputs_with_grad, backward_inputs) + grads = tuple(inp._grad_ivar() for inp in detached_inputs + if isinstance(inp, (core.VarBase, core.eager.Tensor))) + return grads + + +def recompute_hybrid(ctx, function, *args, **kwargs): + """ + # NODTE(shenliang03)The current hybrid parallel recompute has limitations. + # It cannot handle the following situations: + # 1. The calculation output of recompute, there are tensors that do not require gradients. + # 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach(). + # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor + + Parameters: + ctx(dict): include 'mp_group', 'offload', and 'partition' keys. the key 'mp_group' (Group), represents the avtivations are splitted + in which group. the key 'offload' (bool, optional, default=False), represents whether to offload to cpu. the key 'partition' (bool, optional, default=False), + represents whether to split activations in the mp_group. and some keys such as 'segments' and 'preserve_rng_state' are invalid here, they are useful in + 'recompute_sequential' API. + function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model + whose intermediate activations will be released to save memory in forward stage and will be recomputed + in backward stage for gradient calculation. + *args(Tensor): inputs(tuple) to the function. + + **kwargs(Dict): inputs(dict) to the function. + + Returns: + Output of function on args and kwargs. + + """ + mp_group = ctx.get('mp_group', None) + assert mp_group is not None, "ctx must contains mp_group and mp_group can not be None." + + offload = ctx.get('offload', False) + partition = ctx.get('partition', False) + + all_outputs = [] + _HPRecomputeFunction.apply(function, all_outputs, mp_group, offload, + partition, *args, **kwargs) + + if len(all_outputs) == 1: + return all_outputs[0] + else: + for output in all_outputs: + if paddle.is_tensor(output) and not utils.is_float_tensor(output): + output.stop_gradient = True + + return tuple(all_outputs) diff --git a/python/paddle/distributed/fleet/utils/__init__.py b/python/paddle/distributed/fleet/utils/__init__.py index 1bf90a22e375c..db44048ac0042 100644 --- a/python/paddle/distributed/fleet/utils/__init__.py +++ b/python/paddle/distributed/fleet/utils/__init__.py @@ -15,7 +15,8 @@ from .fs import LocalFS # noqa: F401 from .fs import HDFSClient # noqa: F401 from .ps_util import DistributedInfer # noqa: F401 -from .recompute import recompute # noqa: F401 +from paddle.distributed import fleet +import paddle.utils.deprecated as deprecated from . import log_util # noqa: F401 from . 
import hybrid_parallel_util # noqa: F401 @@ -23,3 +24,11 @@ __all__ = [ #noqa "LocalFS", "recompute", "DistributedInfer", "HDFSClient" ] + + +@deprecated(since="2.4.0", + update_to="paddle.distributed.fleet.recompute", + level=1, + reason="Please use new recompute API(fleet.recompute) ") +def recompute(function, *args, **kwargs): + return fleet.recompute(function, *args, **kwargs) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py index 11ca15fd33104..a7180aecc1f0e 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute.py @@ -19,8 +19,8 @@ import paddle from paddle.autograd import PyLayer -from paddle.distributed.fleet.utils import recompute import random +from paddle.distributed import fleet import paddle.fluid.layers as layers @@ -53,48 +53,66 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], + use_fleet_sq=False, + segments=1, + use_raw_recompute=False, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs + self.use_fleet_sq = use_fleet_sq + self.use_raw_recompute = use_raw_recompute + self.segments = segments + self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) - def forward(self, inputs): + if self.use_fleet_sq and not use_raw_recompute: + self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, + self.runfunc2, self.runfunc3, + self.runfunc4) - if 0 in self.recompute_blocks: - inputs = recompute(self.runfunc0, inputs) - else: - inputs = self.runfunc0(inputs) + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] - if 1 in self.recompute_blocks: - inputs = recompute(self.runfunc1, inputs) - else: - inputs = self.runfunc1(inputs) + # default segments = 2 + if use_raw_recompute: + self.layers = [ + paddle.nn.Sequential(self.runfunc0, self.runfunc1), + paddle.nn.Sequential(self.runfunc2, self.runfunc3, + self.runfunc4) + ] - if 2 in self.recompute_blocks: - inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) - else: - inputs = self.runfunc2(inputs) + def forward(self, inputs): - if 3 in self.recompute_blocks: - inputs = recompute(self.runfunc3, inputs) - else: - inputs = self.runfunc3(inputs) + if self.use_fleet_sq and not self.use_raw_recompute: + return fleet.recompute_sequential({"segments": self.segments}, + self.runfuncs, inputs) - if 4 in self.recompute_blocks: - inputs = recompute(self.runfunc4, inputs) - else: - inputs = self.runfunc4(inputs) + if self.use_raw_recompute: + inputs = fleet.recompute(self.layers[0], inputs) + return self.layers[1](inputs) + + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = fleet.recompute(self.layers[i], inputs, + **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) return inputs def run_model(recompute_block=[], recompute_kwargs={}, + use_fleet_sq=False, + use_raw_recompute=False, + segments=1, enable_autocast=False, pure_fp16=False): gen = paddle.seed(10) @@ -105,6 +123,9 @@ def run_model(recompute_block=[], batch_size, 
input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, + use_fleet_sq=use_fleet_sq, + use_raw_recompute=use_raw_recompute, + segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') optimizer = paddle.optimizer.SGD(learning_rate=0.01, @@ -179,6 +200,34 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # recompute second & fourth block using fleet + loss, param, grad = run_model(recompute_block=[1, 3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # recompute using fleet.recompute_sequential, segments=1 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with base recompute, and segments=2 + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + use_raw_recompute=True, + pure_fp16=pure_fp16) + + # recompute using fleet.recompute_sequential, segments=2 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() @@ -191,7 +240,7 @@ def test_fc_net_with_fp16(self): def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], recompute_kwargs=kwargs) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py index bc97d53485be9..5cddfa2eff8bf 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_dygraph_recompute_for_eager.py @@ -23,7 +23,7 @@ import paddle from paddle.autograd import PyLayer -from paddle.distributed.fleet.utils import recompute +from paddle.distributed import fleet import random import paddle.fluid.layers as layers @@ -57,48 +57,66 @@ class Naive_fc_net(paddle.nn.Layer): def __init__(self, input_size=10, recompute_blocks=[1, 3], + use_fleet_sq=False, + segments=1, + use_raw_recompute=False, recompute_kwargs={}): super(Naive_fc_net, self).__init__() self.recompute_blocks = recompute_blocks self.recompute_kwargs = recompute_kwargs + self.use_fleet_sq = use_fleet_sq + self.use_raw_recompute = use_raw_recompute + self.segments = segments + self.runfunc0 = get_fc_block(0, input_size, is_last=False) self.runfunc1 = get_fc_block(1, input_size, is_last=False) self.runfunc2 = get_fc_block(2, input_size, is_last=False) self.runfunc3 = get_fc_block(3, input_size, is_last=False) self.runfunc4 = get_fc_block(4, input_size, is_last=True) - def forward(self, inputs): + if self.use_fleet_sq and not use_raw_recompute: + self.runfuncs = paddle.nn.Sequential(self.runfunc0, self.runfunc1, + self.runfunc2, self.runfunc3, + self.runfunc4) - if 0 in self.recompute_blocks: - inputs = recompute(self.runfunc0, inputs) - else: - inputs = self.runfunc0(inputs) + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, 
self.runfunc3, + self.runfunc4 + ] - if 1 in self.recompute_blocks: - inputs = recompute(self.runfunc1, inputs) - else: - inputs = self.runfunc1(inputs) + # default segments = 2 + if use_raw_recompute: + self.layers = [ + paddle.nn.Sequential(self.runfunc0, self.runfunc1), + paddle.nn.Sequential(self.runfunc2, self.runfunc3, + self.runfunc4) + ] - if 2 in self.recompute_blocks: - inputs = recompute(self.runfunc2, inputs, **self.recompute_kwargs) - else: - inputs = self.runfunc2(inputs) + def forward(self, inputs): - if 3 in self.recompute_blocks: - inputs = recompute(self.runfunc3, inputs) - else: - inputs = self.runfunc3(inputs) + if self.use_fleet_sq and not self.use_raw_recompute: + return fleet.recompute_sequential({"segments": self.segments}, + self.runfuncs, inputs) - if 4 in self.recompute_blocks: - inputs = recompute(self.runfunc4, inputs) - else: - inputs = self.runfunc4(inputs) + if self.use_raw_recompute: + inputs = fleet.recompute(self.layers[0], inputs) + return self.layers[1](inputs) + + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = fleet.recompute(self.layers[i], inputs, + **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) return inputs def run_model(recompute_block=[], recompute_kwargs={}, + use_fleet_sq=False, + use_raw_recompute=False, + segments=1, enable_autocast=False, pure_fp16=False): gen = paddle.seed(10) @@ -109,6 +127,9 @@ def run_model(recompute_block=[], batch_size, input_size = 1, 10 model = Naive_fc_net(input_size, recompute_blocks=recompute_block, + use_fleet_sq=use_fleet_sq, + use_raw_recompute=use_raw_recompute, + segments=segments, recompute_kwargs=recompute_kwargs) loss_fn = paddle.nn.MSELoss(reduction='mean') optimizer = paddle.optimizer.SGD(learning_rate=0.01, @@ -183,6 +204,28 @@ def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): pure_fp16=pure_fp16) check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + # recompute_sequential with segments=1 using fleet + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with base recompute, and segments=2 + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + use_raw_recompute=True, + pure_fp16=pure_fp16) + + # recompute using fleet.recompute_sequential, segments=2 + loss, param, grad = run_model(recompute_block=[], + use_fleet_sq=True, + segments=2, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + def test_fc_net_with_dropout(self): self.test_base_case() @@ -201,7 +244,7 @@ def test_fc_net_with_fp16(self): def test_recompute_kwargs(self): paddle.set_device("gpu") kwargs = {"is_test": False} - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], recompute_kwargs=kwargs) diff --git a/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py b/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py index 20196a98eb144..19eaeb462de64 100644 --- a/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py +++ b/python/paddle/fluid/tests/unittests/collective/multinode/dygraph_hybrid_recompute.py @@ -99,7 +99,7 @@ def __init__(self, mp, seed, m, n, k): def forward(self, x): if self.mp: - return 
recompute(self.layers, x) + return fleet.recompute(self.layers, x) else: return self.layers(x) diff --git a/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py b/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py new file mode 100755 index 0000000000000..03aa397a5a4c3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_recompute_hybrid.py @@ -0,0 +1,220 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +from paddle.autograd import PyLayer +from paddle.distributed.fleet.utils import recompute +import random +from paddle.distributed import fleet + +import paddle.fluid.layers as layers + + +def get_fc_block(block_idx, input_size, is_last=False): + block_name = "block_" + str(block_idx) + block = paddle.nn.Sequential( + (block_name + "_fc_0", + paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_dropout", paddle.nn.Dropout(p=0.5)), + (block_name + "_relu_1", paddle.nn.ReLU()), + (block_name + "_fc_1", + paddle.nn.Linear(input_size, input_size, bias_attr=False)), + (block_name + "_relu_2", paddle.nn.ReLU()), + ) + if is_last: + block.add_sublayer(block_name + "_fc_2", + paddle.nn.Linear(input_size, 1, + bias_attr=False)) # add sublayer + else: + block.add_sublayer(block_name + "_fc_2", + paddle.nn.Linear(input_size, + input_size, + bias_attr=False)) # add sublayer + return block + + +class Naive_fc_net(paddle.nn.Layer): + + def __init__(self, + input_size=10, + recompute_blocks=[1, 3], + offload=False, + partition=False, + recompute_kwargs={}): + super(Naive_fc_net, self).__init__() + self.recompute_blocks = recompute_blocks + self.recompute_kwargs = recompute_kwargs + self.offload = offload + self.partition = partition + + self.runfunc0 = get_fc_block(0, input_size, is_last=False) + self.runfunc1 = get_fc_block(1, input_size, is_last=False) + self.runfunc2 = get_fc_block(2, input_size, is_last=False) + self.runfunc3 = get_fc_block(3, input_size, is_last=False) + self.runfunc4 = get_fc_block(4, input_size, is_last=True) + + self.layers = [ + self.runfunc0, self.runfunc1, self.runfunc2, self.runfunc3, + self.runfunc4 + ] + + def forward(self, inputs): + for i in range(len(self.layers)): + if i in self.recompute_blocks: + inputs = fleet.recompute_hybrid( + { + "mp_group": fleet.fleet._hcg.get_model_parallel_group(), + "offload": self.offload, + "partition": self.partition + }, self.layers[i], inputs, **self.recompute_kwargs) + else: + inputs = self.layers[i](inputs) + + return inputs + + +def run_model(recompute_block=[], + recompute_kwargs={}, + offload=False, + partition=False, + enable_autocast=False, + pure_fp16=False): + gen = paddle.seed(10) + gen.manual_seed(10) + np.random.seed(10) + random.seed(10) + + batch_size, input_size = 1, 10 + model = Naive_fc_net(input_size, + recompute_blocks=recompute_block, + offload=offload, + partition=partition, + 
recompute_kwargs=recompute_kwargs) + loss_fn = paddle.nn.MSELoss(reduction='mean') + optimizer = paddle.optimizer.SGD(learning_rate=0.01, + parameters=model.parameters()) + + model = fleet.distributed_model(model) + optimizer = fleet.distributed_optimizer(optimizer) + + if enable_autocast: + scaler = paddle.amp.GradScaler() + scaler = fleet.distributed_scaler(scaler) + + loss_ = [] + param_ = [] + grad_ = [] + for step in range(10): + + x_data = np.random.randn(batch_size, input_size).astype(np.float32) + x = paddle.to_tensor(x_data) + # x.stop_gradient = False + level = 'O2' if pure_fp16 else 'O1' + with paddle.amp.auto_cast(True, level=level): + y_pred = model(x) + loss = y_pred.mean() + if enable_autocast: + scaler.scale(loss).backward() + scaler.minimize(optimizer, loss) + else: + loss_.append(np.asarray(loss).tolist()) + loss.backward() + optimizer.step() + + param_.append(np.asarray(model.parameters()[9]).tolist()) + grad_.append(np.asarray(model.parameters()[3]._grad_ivar()).tolist()) + + optimizer.clear_grad() + return loss_, param_, grad_ + + +class TestPyLayer(unittest.TestCase): + + def setUp(self): + strategy = fleet.DistributedStrategy() + self.model_parallel_size = 2 + self.data_parallel_size = 1 + self.pipeline_parallel_size = 1 + strategy.hybrid_configs = { + "dp_degree": self.data_parallel_size, + "mp_degree": self.model_parallel_size, + "pp_degree": self.pipeline_parallel_size, + } + fleet.init(is_collective=True, strategy=strategy) + + def test_base_case(self, enable_autocast=False, pure_fp16=False): + + def check_identical(loss_ref, param_ref, grad_ref, loss, param, grad): + self.assertEqual(loss_ref, loss) + self.assertEqual(param_ref, param) + self.assertEqual(grad_ref, grad) + + # without recompute + loss_ref, param_ref, grad_ref = run_model( + recompute_block=[], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + + # with recompute, offload=False, partition=False + loss, param, grad = run_model(recompute_block=[1, 3], + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=True, partition=False + loss, param, grad = run_model(recompute_block=[1, 2, 3], + offload=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=False, partition=True + loss, param, grad = run_model(recompute_block=[1], + partition=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + # with recompute, offload=True, partition=True + loss, param, grad = run_model(recompute_block=[1, 3, 4], + offload=True, + partition=True, + enable_autocast=enable_autocast, + pure_fp16=pure_fp16) + check_identical(loss_ref, param_ref, grad_ref, loss, param, grad) + + def test_fc_net_with_dropout(self): + self.test_base_case() + + def test_fc_net_with_amp(self): + self.test_base_case(enable_autocast=True) + + def test_fc_net_with_fp16(self): + self.test_base_case(enable_autocast=True, pure_fp16=True) + + def test_recompute_kwargs(self): + paddle.set_device("gpu") + kwargs = {"is_test": False} + with self.assertRaises(TypeError): + loss_ref, param_ref, grad_ref = run_model(recompute_block=[2], + recompute_kwargs=kwargs) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py index 
8773e8d47ed3c..10243a0faa944 100644 --- a/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py +++ b/python/paddle/fluid/tests/unittests/test_pipeline_parallel.py @@ -26,5 +26,11 @@ def test_pipeline_parallel(self): self.run_mnist_2gpu('hybrid_parallel_pp_alexnet.py') +class TestModelParallelWithRecompute(TestMultipleGpus): + + def test_model_parallel_with_recompute(self): + self.run_mnist_2gpu("dygraph_recompute_hybrid.py") + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index f25b00cb4beef..a8729306759d0 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -34,7 +34,7 @@ from paddle.autograd import PyLayer from .gate import NaiveGate, GShardGate, SwitchGate, BaseGate from .utils import count_by_gate -from paddle.distributed.fleet.meta_parallel.pp_utils.utils import _hp_recompute +from paddle.distributed import fleet from paddle import fluid from paddle.fluid.framework import in_dygraph_mode @@ -424,8 +424,8 @@ def experts_fwd(x, fwd_expert_count, experts): if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: - x = _hp_recompute(experts_fwd, x, fwd_expert_count.numpy(), - self.experts) + x = fleet.recompute_hybrid(self.recompute_ctx, experts_fwd, x, + fwd_expert_count.numpy(), self.experts) out_batch_size = inp.shape[0] if len(gate.shape) == 2: diff --git a/python/setup.py.in b/python/setup.py.in index 3d400881de382..a8d5b7f23705e 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -296,6 +296,7 @@ packages=['paddle', 'paddle.distributed.launch.plugins', 'paddle.distributed.launch.utils', 'paddle.distributed.fleet.base', + 'paddle.distributed.fleet.recompute', 'paddle.distributed.fleet.elastic', 'paddle.distributed.fleet.meta_optimizers', 'paddle.distributed.fleet.meta_optimizers.sharding',
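
Usage sketch (reviewer note, not part of the patch): the calls below mirror the unit tests added above and show how the relocated APIs are invoked after this change -- fleet.recompute on a single block, fleet.recompute_sequential on a paddle.nn.Sequential split into segments, and fleet.recompute_hybrid (indicated only in a comment, since it needs fleet.init with hybrid_configs and a multi-GPU launch). Layer sizes and names are illustrative assumptions, not taken from the PR.

    # illustrative only; mirrors the call patterns exercised by the new tests
    import paddle
    from paddle.distributed import fleet


    class Net(paddle.nn.Layer):

        def __init__(self, hidden=10):
            super(Net, self).__init__()
            self.block0 = paddle.nn.Sequential(
                paddle.nn.Linear(hidden, hidden), paddle.nn.ReLU())
            self.block1 = paddle.nn.Sequential(
                paddle.nn.Linear(hidden, hidden), paddle.nn.ReLU())
            self.head = paddle.nn.Linear(hidden, 1)
            # Sequential container consumed by recompute_sequential below.
            self.body = paddle.nn.Sequential(self.block0, self.block1)

        def forward(self, x, use_sequential=False):
            if use_sequential:
                # checkpoint the whole Sequential, recomputing it in 2 chunks
                x = fleet.recompute_sequential({"segments": 2}, self.body, x)
            else:
                # checkpoint a single block; **kwargs are now forwarded to it
                x = fleet.recompute(self.block0, x)
                x = self.block1(x)
            return self.head(x)


    if __name__ == "__main__":
        paddle.set_device("gpu")  # the recompute RNG-state path expects a GPU device
        model = Net()
        loss = model(paddle.randn([4, 10])).mean()
        loss.backward()
        # Under hybrid parallel (fleet.init(is_collective=True, strategy=...) with
        # hybrid_configs and a distributed launch), a block is wrapped instead as:
        #   fleet.recompute_hybrid(
        #       {"mp_group": fleet.fleet._hcg.get_model_parallel_group(),
        #        "offload": False, "partition": False},
        #       block, x)

The old entry points (paddle.distributed.fleet.utils.recompute and the pp_utils _hp_recompute helper) keep working through the deprecation shim and the pp_layers/moe_layer call-site updates in this diff, so existing user code does not need to change immediately.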