Refactor recompute #45348

Closed

Changes from all commits (35 commits)
362baa3
add recompute_sequential
wuhuachaocoding Aug 23, 2022
49d2463
add recompute dirs.
wuhuachaocoding Aug 23, 2022
d2b2519
update cite.
wuhuachaocoding Aug 23, 2022
b15ce9f
update.
wuhuachaocoding Aug 23, 2022
18462ac
recompute unify.
wuhuachaocoding Aug 23, 2022
91bb00f
update recompute.
wuhuachaocoding Aug 23, 2022
a249601
update recompute.
wuhuachaocoding Aug 24, 2022
1239180
refact recompute.
wuhuachaocoding Aug 24, 2022
da0b2a9
update.
wuhuachaocoding Aug 24, 2022
df8cc54
update.
wuhuachaocoding Aug 25, 2022
d73393f
update test.
wuhuachaocoding Aug 25, 2022
47d2529
update.
wuhuachaocoding Aug 25, 2022
02d0807
update test.
wuhuachaocoding Aug 25, 2022
ec8061a
add package in setup.py.in
wuhuachaocoding Aug 25, 2022
c578851
update first.
wuhuachaocoding Aug 29, 2022
2dd523d
Merge remote-tracking branch 'upstream/develop' into recompute_unify
wuhuachaocoding Aug 29, 2022
abd0852
update recompute_hybrid.py
wuhuachaocoding Aug 29, 2022
9e5788d
update input of **kwargs.
wuhuachaocoding Aug 29, 2022
153cb99
update.
wuhuachaocoding Aug 29, 2022
107b6a1
update.
wuhuachaocoding Aug 29, 2022
db27883
update recompute_hybrid.py
wuhuachaocoding Aug 30, 2022
69e8ebb
update test.
wuhuachaocoding Aug 30, 2022
307d8e7
update.
wuhuachaocoding Aug 30, 2022
5218611
update.
wuhuachaocoding Aug 30, 2022
ccd4144
update.
wuhuachaocoding Aug 30, 2022
ec2b973
update.
wuhuachaocoding Aug 30, 2022
f261c51
update.
wuhuachaocoding Aug 31, 2022
a29cb3f
update note.
wuhuachaocoding Aug 31, 2022
f038f14
Merge remote-tracking branch 'upstream/develop' into recompute_unify
wuhuachaocoding Sep 2, 2022
ab38767
update recompute_ctx.
wuhuachaocoding Sep 2, 2022
a99c046
update __all__.
wuhuachaocoding Sep 13, 2022
b54e2be
Merge remote-tracking branch 'upstream/develop' into recompute_unify
wuhuachaocoding Sep 13, 2022
4af8ad5
update rng.
wuhuachaocoding Sep 13, 2022
1d615a5
Merge remote-tracking branch 'upstream/develop' into recompute_unify
wuhuachaocoding Sep 15, 2022
53b1fcc
update.
wuhuachaocoding Sep 15, 2022

4 changes: 4 additions & 0 deletions python/paddle/distributed/fleet/__init__.py
@@ -32,6 +32,7 @@
from .model import distributed_model
from .optimizer import distributed_optimizer
from .scaler import distributed_scaler
import paddle.distributed.fleet.recompute as rc

__all__ = [ #noqa
"CommunicateTopology", "UtilBase", "HybridCommunicateGroup",
@@ -90,3 +91,6 @@
shrink = fleet.shrink
get_hybrid_communicate_group = fleet.get_hybrid_communicate_group
distributed_scaler = distributed_scaler
recompute = rc.recompute
recompute_sequential = rc.recompute_sequential
recompute_hybrid = rc.recompute_hybrid
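
The three aliases above expose the unified recompute entry points directly on paddle.distributed.fleet. Below is a minimal single-card sketch of the plain fleet.recompute alias; the layer sizes and input shape are invented for illustration, and a GPU device is assumed because recompute preserves the CUDA RNG state by default. Sketches for recompute_hybrid and recompute_sequential follow the PipelineLayer hunk and the model.py diff further down.

import paddle
from paddle.distributed import fleet

# A small Sequential block whose activations we do not want to keep alive
# between the forward and backward passes.
block = paddle.nn.Sequential(
    paddle.nn.Linear(64, 64),
    paddle.nn.ReLU(),
    paddle.nn.Linear(64, 64),
)

x = paddle.randn([8, 64])
x.stop_gradient = False

# fleet.recompute(function, *args): run the block without saving its
# intermediate activations and recompute them during backward.
# Assumes a GPU device (the default preserve_rng_state path uses the CUDA RNG).
y = fleet.recompute(block, x)
y.sum().backward()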

@@ -49,7 +49,7 @@
import paddle
from paddle.fluid.dygraph.layers import Layer
from ...utils.log_util import logger, layer_to_str
from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting
from paddle.distributed import fleet
from paddle.fluid.framework import in_dygraph_mode

__all__ = []
@@ -309,13 +309,13 @@ def __init__(self,
self._loss_fn = loss_fn
self._topo = topology
self._recompute_interval = recompute_interval
self.recompute_ctx = recompute_ctx

if recompute_interval > 0:
assert recompute_ctx is not None, "recompute_ctx must be not None for recompute."

offload = recompute_ctx.get('offload', False)
partition = recompute_ctx.get('partition', False)
_initialize_recompute_setting(offload, partition)
logger.info(
"Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}"
.format(offload, partition))
@@ -638,7 +638,8 @@ def forward(self, input, chunk_id=None):
input = (input, )

if self._need_recompute(funcs, input):
input = _hp_recompute(
input = fleet.recompute_hybrid(
self.recompute_ctx,
self.forward_function(start_idx, end_idx), *input)
else:
input = self.forward_function(start_idx, end_idx)(*input)
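
For context, this is roughly the call pattern the hunk above switches to: instead of priming module-level globals via _initialize_recompute_setting, the layer forwards a recompute_ctx dict straight into fleet.recompute_hybrid. The sketch assumes a script launched with paddle.distributed.launch and fleet.init already configured with a hybrid topology; the mp_group key and the toy forward_chunk function are assumptions, since the hunk only shows the offload and partition keys being read.

import paddle
from paddle.distributed import fleet

# Assumes the process was started via paddle.distributed.launch and that
# fleet.init(is_collective=True, strategy=...) set up a hybrid (dp/mp/pp)
# topology beforehand.
hcg = fleet.get_hybrid_communicate_group()

def forward_chunk(x):
    # stand-in for self.forward_function(start_idx, end_idx) above
    return paddle.nn.functional.relu(x) * 2.0

recompute_ctx = {
    "mp_group": hcg.get_model_parallel_group(),  # assumed key, not shown in the hunk
    "offload": False,    # if True, park saved activations in host memory
    "partition": False,  # if True, shard saved activations across the mp group
}

x = paddle.randn([4, 16])
x.stop_gradient = False
out = fleet.recompute_hybrid(recompute_ctx, forward_chunk, x)
out.mean().backward()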

@@ -14,7 +14,6 @@
import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
@@ -61,8 +60,6 @@ def __init__(self, layers, hcg, strategy):

p2p.initialize_p2p_groups(hcg, self._using_cache)

_initialize_recompute_hcg(hcg)

self.global_rank = self._hcg.get_global_rank()
self.micro_batch_id = 0

@@ -12,6 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .utils import get_tensor_bytes

__all__ = []

@@ -13,12 +13,12 @@
# limitations under the License.

import paddle
from .utils import paddle_2_number, number_2_dtype
from ...utils.log_util import logger
import numpy as np
from paddle import _C_ops, _legacy_C_ops
import paddle.fluid.core as core
from paddle.fluid.framework import _in_legacy_dygraph, _non_static_mode, in_dygraph_mode
from .utils import paddle_2_number, paddle_2_number, number_2_dtype

_hcg = None
_use_cache = False

208 changes: 0 additions & 208 deletions python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py
@@ -12,16 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib

import paddle
from paddle.fluid import core
from paddle import _C_ops, _legacy_C_ops
from paddle.autograd import PyLayer
from paddle.fluid import framework
from ...utils.recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
from ..parallel_layers.random import get_rng_state_tracker
from paddle.fluid.framework import in_dygraph_mode

__all__ = []

@@ -88,23 +81,6 @@ def get_tensor_bytes(tensor):
return tensor.numel() * elem_size


_hcg = None
_recompute_offload = False
_recompute_partition = False


def _initialize_recompute_setting(is_offload, is_partition):
global _recompute_offload, _recompute_partition

_recompute_offload = is_offload
_recompute_partition = is_partition


def _initialize_recompute_hcg(hcg):
global _hcg
_hcg = hcg


def _all_gather(tensor, group=None, use_calc_stream=True):
"""
The main difference with paddle.distributed.all_gather:
@@ -117,187 +93,3 @@ def _all_gather(tensor, group=None, use_calc_stream=True):
).nranks if group is None else group.nranks
return _legacy_C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks)


def _split_activation(tensor):
global _hcg

mp_degree = _hcg.get_model_parallel_world_size()
mp_rank = _hcg.get_model_parallel_rank()
if mp_degree < 2:
return tensor

tensor_numel = paddle.numel(tensor)
assert tensor_numel != 0, "can't recompute zero element"
assert tensor_numel % mp_degree == 0, "The capacity of the activation () cannot be divisible by mp_degree()".format(
tensor_numel, mp_degree)

# use inplace operation to save memory
data = tensor.flatten_()

part_size = tensor_numel // mp_degree
start = part_size * mp_rank
end = start + part_size
return data[start:end]


def _merge_activation(tensor):
global _hcg
mp_degree = _hcg.get_model_parallel_world_size()
mp_rank = _hcg.get_model_parallel_rank()
mp_group = _hcg.get_model_parallel_group()
if mp_degree < 2:
return tensor
return _all_gather(tensor, group=mp_group)


class _HPRecomputeFunction(PyLayer):
"""
Compared with paddle.distributed.fleet.utils.recompute, there are the following differences:
1. In order to support PipeLineParallel, the input of recompute is modified to ensure that the input can be tuple type.
2. Offload support for activation
3. Support MP segmentation of activation to further reduce cuda memory
4. Adapt to the random state of MP
"""

@staticmethod
def forward(ctx, run_function, all_outputs, *args):
check_recompute_necessary(args)

# store for recomputing
ctx.run_function = run_function

# store the rng states
ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
).get_states_tracker()

# save input for backward
ctx.inputs = []
ctx.tensor_indices = []
ctx.tensor_shapes = []
tensor_inputs = []

cur_device = paddle.get_device()
assert 'gpu:' in paddle.get_device(
), "Recompute with RNG is not support current device: {}.".format(
cur_device)

# TODO support AMP
tracer = framework._dygraph_tracer()
ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
if tracer._amp_level == core.AmpLevel.O2:
ctx.amp_level = 'O2'
elif tracer._amp_level in (core.AmpLevel.O1, core.AmpLevel.O0):
ctx.amp_level = 'O1'
else:
raise ValueError("unsupported amp level: {}".format(
tracer._amp_level))
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

with paddle.no_grad():
outputs = run_function(*args)

for i, arg in enumerate(args):
if paddle.is_tensor(arg):
state = arg.stop_gradient
if _recompute_partition:
ctx.tensor_shapes.append(arg.shape)
partition = _split_activation(arg.detach()).clone()
# TODO(shenliang03) not use calculate stream to D2H to speed
arg = partition.cpu() if _recompute_offload else partition
else:
arg = arg.cpu() if _recompute_offload else arg
arg.stop_gradient = state
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)

ctx.save_for_backward(*tensor_inputs)

if paddle.is_tensor(outputs):
all_outputs += [outputs]
return outputs
else:
all_outputs += outputs
return tuple(outputs)

@staticmethod
def backward(ctx, *args):
with paddle.fluid.dygraph.guard():
# Restore inputs
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensor_shapes = ctx.tensor_shapes
tensors = list(ctx.saved_tensor())

device_id = paddle.distributed.ParallelEnv().device_id
for i, idx in enumerate(tensor_indices):
if _recompute_partition:
state = tensors[i].stop_gradient
tensors[i] = _merge_activation(
tensors[i]).detach().reshape_(tensor_shapes[i])
tensors[i].stop_gradient = state
inputs[idx] = tensors[i].cuda(
device_id) if _recompute_offload else tensors[i]

tracer = framework._dygraph_tracer()
tracer._has_grad = True

# need restore auto_cast state as well as w/b list
with swith_rng_state_tracker(ctx.fwd_cuda_rng_state,
ctx.fwd_cuda_rng_state_tracker):
with paddle.amp.auto_cast(enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list,
level=ctx.amp_level):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)

if isinstance(outputs, (core.VarBase, core.eager.Tensor)):
outputs = (outputs, )
assert len(outputs) == len(args)

forward_outputs_with_grad = []
backward_inputs = []

for i in range(len(outputs)):
if isinstance(
outputs[i],
(core.VarBase,
core.eager.Tensor)) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i])
backward_inputs.append(args[i])

if len(forward_outputs_with_grad) == 0:
raise RuntimeError(
"none of output has stop_gradient=False, this recompute() is not necessary"
)

# actually backward
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
grads = tuple(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, (core.VarBase, core.eager.Tensor)))
return grads


def _hp_recompute(function, *args):
# NODTE(shenliang03)The current hybrid parallel recompute has limitations.
# It cannot handle the following situations:
# 1. The calculation output of recompute, there are tensors that do not require gradients.
# 2. The forward output tensor has no gradient. This problem can be solved temporarily by detach().
# 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor

all_outputs = []
_HPRecomputeFunction.apply(function, all_outputs, *args)

if len(all_outputs) == 1:
return all_outputs[0]
else:
for output in all_outputs:
if paddle.is_tensor(output) and not is_float_tensor(output):
output.stop_gradient = True

return tuple(all_outputs)
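
The deleted _split_activation and _merge_activation helpers above implement the partition and offload mechanics that the new recompute_hybrid module is expected to take over. The following single-process sketch illustrates only the arithmetic: mp_degree and mp_rank are plain ints here, and the gather is simulated locally, whereas the real code takes them from the hybrid communicate group and merges with c_allgather over the model-parallel group.

import paddle

mp_degree, mp_rank = 2, 0                    # pretend: rank 0 of a 2-way mp group

activation = paddle.randn([4, 8])            # 32 elements, divisible by mp_degree
flat = activation.flatten()                  # the original used flatten_() in place

part_size = flat.shape[0] // mp_degree
local_part = flat[mp_rank * part_size:(mp_rank + 1) * part_size]  # this rank's slice
offloaded = local_part.cpu()                 # 'offload': park the slice in host memory

# Backward time: every rank contributes its slice (an all_gather over the
# model-parallel group in the real code), then the original shape is restored.
all_parts = [flat[r * part_size:(r + 1) * part_size] for r in range(mp_degree)]
merged = paddle.concat(all_parts).reshape(activation.shape)
assert bool(paddle.allclose(merged, activation))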
42 changes: 0 additions & 42 deletions python/paddle/distributed/fleet/model.py
@@ -20,46 +20,9 @@
from .meta_parallel import TensorParallel, model_parallel_random_seed
from .meta_parallel import PipelineParallel, ShardingParallel, PipelineParallelWithInterleave, PipelineLayer
from paddle.fluid import core
from paddle.distributed.fleet.utils.recompute import LegacyRecomputeFunction
from paddle.fluid.dygraph.varbase_patch_methods import _grad_scalar
from paddle.distributed import fleet


class _RecomputeModelWrapper(paddle.nn.Layer):

def __init__(self, model, segments=2, preserve_rng_state=True):
super(_RecomputeModelWrapper, self).__init__()
assert isinstance(model, paddle.nn.Sequential), (
"The model passed to RecomputeModelWrapper must be of type "
"paddle.nn.Sequential.")
self._model = model
self._segments = segments
self._preserve_rng_state = preserve_rng_state
self._layers = list(model.children())
self._segment_size = len(self._layers) // segments

def _run_func(self, begin, end):

def do_run(input):
for i in range(begin, end):
input = self._layers[i](input)
return input

return do_run

def _checkpoint(self, func, *args, **kwargs):
return LegacyRecomputeFunction.apply(func, self._preserve_rng_state,
*args)

def forward(self, input):
end = 0
for begin in range(0, self._segment_size * (self._segments - 1),
self._segment_size):
end = begin + self._segment_size
input = self._checkpoint(self._run_func(begin, end), input)
return self._run_func(end, len(self._layers))(input)


_grad_scalar = None


@@ -125,7 +88,6 @@ def forward(self, x):
return model

amp_enable = False
recompute_enable = False
strategy = fleet_env._user_defined_strategy
if strategy.amp == True:
amp_enable = True
@@ -154,10 +116,6 @@ def forward(self, x):
decr_every_n_nan_or_inf=decr_every_n_nan_or_inf,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)

if strategy.recompute == True:
recompute_enable = True
model = _RecomputeModelWrapper(model)

if strategy.heter_ccl_mode == True:
distributed_model = paddle.DataParallel(
model,
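
With _RecomputeModelWrapper and the strategy.recompute branch removed from model.py, segment-wise recompute for a plain paddle.nn.Sequential presumably goes through the new fleet.recompute_sequential alias instead. The sketch below is a hedged guess at that replacement: the ctx-dict signature with a segments key mirrors the recompute_hybrid call shown earlier but is an assumption, since this diff does not include recompute_sequential itself; as before, a GPU device is assumed.

import paddle
from paddle.distributed import fleet

model = paddle.nn.Sequential(
    paddle.nn.Linear(32, 32),
    paddle.nn.ReLU(),
    paddle.nn.Linear(32, 32),
    paddle.nn.ReLU(),
)

x = paddle.randn([8, 32])
x.stop_gradient = False

# Roughly what the deleted wrapper did by hand: split the Sequential into
# segments and recompute each segment's activations during backward.
y = fleet.recompute_sequential({"segments": 2}, model, x)
y.mean().backward()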