Commit a99c046

update __all__.
wuhuachaocoding committed Sep 13, 2022
1 parent ab38767 commit a99c046
Showing 4 changed files with 7 additions and 15 deletions.
@@ -16,10 +16,7 @@
 from paddle.fluid import core
 from paddle import _C_ops, _legacy_C_ops
 
-__all__ = [
-    "is_float_tensor", "get_tensor_dtype", "paddle_2_number", "number_2_dtype",
-    "get_tensor_bytes", "_all_gather"
-]
+__all__ = []
 
 FLOAT_TYPE_DICT = {
     paddle.float16: "float16",
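For context on the pattern above: an empty __all__ changes only what a wildcard import exposes; the helpers themselves stay importable. A minimal sketch with a hypothetical stand-in module (the names below are illustrative, not part of this commit):

    # helpers.py -- hypothetical module mirroring the change above
    def get_tensor_bytes(tensor):
        # stand-in body; the real helper returns the tensor's size in bytes
        return 0

    __all__ = []  # wildcard imports of this module now export nothing

    # client code:
    #   from helpers import *                  # binds no names
    #   from helpers import get_tensor_bytes   # explicit import still works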
python/paddle/distributed/fleet/recompute/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -15,4 +15,4 @@
 from .recompute import recompute, recompute_sequential
 from .recompute_hybrid import recompute_hybrid
 
-__all__ = ["recompute", "recompute_sequential", "recompute_hybrid"]
+__all__ = []
python/paddle/distributed/fleet/recompute/recompute.py (5 changes: 1 addition & 4 deletions)
@@ -30,10 +30,7 @@
 ch.setFormatter(formatter)
 logger.addHandler(ch)
 
-__all__ = [
-    "recompute", "recompute_sequential", "swith_rng_state_tracker",
-    "check_recompute_necessary", "detach_variable"
-]
+__all__ = []
 
 
 def detach_variable(inputs):
python/paddle/distributed/fleet/recompute/recompute_hybrid.py (10 changes: 4 additions & 6 deletions)
@@ -25,7 +25,7 @@
 from .recompute import check_recompute_necessary, detach_variable, swith_rng_state_tracker
 from ..meta_parallel.pp_utils import utils
 
-__all__ = ["recompute_hybrid"]
+__all__ = []
 
 
 def _split_activation(tensor, mp_group):
@@ -216,11 +216,10 @@ def recompute_hybrid(ctx, function, *args, **kwargs):
     # 3. Here, we only use float dtype to distinguish whether a gradient is needed in output tensor
 
     Parameters:
-        ctx(dict): include 'mp_group', 'offload', 'partition' and 'preserve_rng_state' keys. the key 'mp_group' (Group), represents the avtivations are splitted
+        ctx(dict): include 'mp_group', 'offload', and 'partition' keys. the key 'mp_group' (Group), represents the avtivations are splitted
             in which group. the key 'offload' (bool, optional, default=False), represents whether to offload to cpu. the key 'partition' (bool, optional, default=False),
-            represents whether to split activations in the mp_group. the key 'preserve_rng_state' (bool, optional, default=True) indicate whether to save the forward rng.
-            If it is True, then the last forward rng value will be restored when the forward recalculation of backpropagation is performed. and some keys such as 'segments',
-            are invalid here, they are useful in 'recompute_sequential' API.
+            represents whether to split activations in the mp_group. and some keys such as 'segments' and 'preserve_rng_state' are invalid here, they are useful in
+            'recompute_sequential' API.
         function(paddle.nn.Layer): layer of sequence of layers that describes part of forward pass of the model
             whose intermediate activations will be released to save memory in forward stage and will be recomputed
             in backward stage for gradient calculation.
@@ -237,7 +236,6 @@ def recompute_hybrid(ctx, function, *args, **kwargs):

     offload = ctx.get('offload', False)
     partition = ctx.get('partition', False)
-    preserve_rng_state = ctx.get('preserve_rng_state', True)
 
     all_outputs = []
     _HPRecomputeFunction.apply(function, all_outputs, mp_group, offload,
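With preserve_rng_state dropped from the hybrid path, callers pass only the keys this API still reads. A hedged usage sketch, assuming a launched hybrid-parallel job (the layer, shapes, and process group below are illustrative, not taken from this commit):

    import paddle
    import paddle.distributed as dist
    from paddle.distributed.fleet.recompute import recompute_hybrid

    dist.init_parallel_env()         # requires a distributed launch
    mp_group = dist.new_group([0])   # stand-in for the model-parallel group

    block = paddle.nn.Linear(1024, 1024)   # stand-in for a transformer block
    x = paddle.randn([8, 1024])
    x.stop_gradient = False

    ctx = {
        "mp_group": mp_group,   # group in which activations are split
        "offload": False,       # optional: stash activations on CPU
        "partition": False,     # optional: partition activations across mp_group
        # 'segments' and 'preserve_rng_state' are not read by this API;
        # they only apply to recompute_sequential.
    }

    # block's intermediate activations are freed in the forward pass and
    # recomputed during backward to save memory.
    out = recompute_hybrid(ctx, block, x)
    out.sum().backward()

Note that recompute_hybrid stays importable by name even though the package's __all__ is now empty, since __init__.py still imports it explicitly.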
