add enable_grad, refactor no_grad, set_grad_enabled (#50560)
* add enable_grad, refactor no_grad, set_grad_enabled

* fix bug

* fix bug

* format

* fix bug

* format

* format

* fix doc

* fix

* fix

* fix bug

* fix comment
zh794390558 committed Feb 22, 2023
1 parent a35dbc2 commit 499b7f8
Showing 6 changed files with 307 additions and 91 deletions.
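
For reviewers, a rough usage sketch of the API surface this commit touches, assembled from the docstrings and unit tests in the diff below (the tensor values are illustrative only, not part of the change):

    import paddle

    x = paddle.to_tensor([1.0], stop_gradient=False)

    # enable_grad re-enables gradient tracking inside a no_grad scope.
    with paddle.no_grad():
        with paddle.enable_grad():
            y = x * 2
    assert y.stop_gradient is False

    # set_grad_enabled works as a context manager (or decorator) driven by a bool flag...
    with paddle.set_grad_enabled(False):
        z = x * 2
    assert z.stop_gradient is True

    # ...and, called bare, takes effect immediately (the mode is applied in __init__).
    paddle.set_grad_enabled(False)
    w = x * 2
    assert w.stop_gradient is True
    paddle.set_grad_enabled(True)

    # is_grad_enabled reports the current dygraph gradient mode.
    assert paddle.is_grad_enabled() is True
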
1 change: 1 addition & 0 deletions python/paddle/__init__.py
@@ -340,6 +340,7 @@

from .autograd import grad # noqa: F401
from .autograd import no_grad # noqa: F401
from .autograd import enable_grad # noqa:F401
from .autograd import set_grad_enabled # noqa: F401
from .autograd import is_grad_enabled # noqa: F401
from .framework import save # noqa: F401
4 changes: 3 additions & 1 deletion python/paddle/autograd/__init__.py
@@ -13,8 +13,10 @@
# limitations under the License.

from ..fluid.dygraph.base import grad # noqa: F401
+from ..fluid.dygraph.base import enable_grad # noqa: F401
from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
-from ..framework import is_grad_enabled, set_grad_enabled # noqa: F401
+from ..fluid.dygraph.base import is_grad_enabled # noqa: F401
+from ..fluid.dygraph.base import set_grad_enabled # noqa: F401
from . import backward_mode # noqa: F401
from .backward_mode import backward # noqa: F401
from .py_layer import PyLayer as PyLayer # noqa: F401
179 changes: 156 additions & 23 deletions python/paddle/fluid/dygraph/base.py
@@ -343,7 +343,114 @@ def __impl__(func, *args, **kwargs):
return __impl__(func)


-class no_grad_:
+class _DecoratorContextManager:
"""Allow a context manager to be used as a decorator"""

def __call__(self, func):
@decorator.decorator
def _decorate_function(func, *args, **kwargs):
with self:
return func(*args, **kwargs)

@decorator.decorator
def _decorate_generator(func, *args, **kwargs):
gen = func(*args, **kwargs)
with self:
for x in gen:
yield x

if inspect.isgeneratorfunction(func):
return _decorate_generator(func)
else:
return _decorate_function(func)

def __enter__(self):
raise NotImplementedError

def __exit__(self, exc_type, exc_value, traceback):
raise NotImplementedError

def clone(self):
# override this method if your children class takes __init__ parameters
return self.__class__()


def is_grad_enabled():
"""
Returns whether current dygraph gradient calculation mode is enabled.
Returns:
bool: True if current dygraph gradient calculation mode is enabled, otherwise false.
Examples:
.. code-block:: python
import paddle
# Dygraph gradient calculation mode is enabled by default.
paddle.is_grad_enabled() # True
with paddle.set_grad_enabled(False):
paddle.is_grad_enabled() # False
paddle.enable_static()
paddle.is_grad_enabled() # False
"""
tracer = framework._dygraph_tracer()
return tracer._has_grad if tracer else False


def _set_grad_enabled(mode):
tracer = framework._dygraph_tracer()
if tracer:
tracer._has_grad = mode


class set_grad_enabled(_DecoratorContextManager):
"""
Create a context which enables or disables dygraph gradient calculation.
Args:
mode(bool): whether to enable (`True`), or disable (`False`) grad.
Returns:
None.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1.], stop_gradient=False)
is_train = False
with paddle.set_grad_enabled(is_train):
y = x * 2
assert(y.stop_gradient == True)
paddle.set_grad_enabled(True)
y = x * 2
assert(y.stop_gradient == False)
paddle.set_grad_enabled(False)
y = x * 2
assert(y.stop_gradient == True)
"""

def __init__(self, mode):
self.prev = is_grad_enabled()
_set_grad_enabled(mode)
self.mode = mode

def __enter__(self):
...

def __exit__(self, *args):
_set_grad_enabled(self.prev)

def clone(self):
return self.__class__(self.mode)


class no_grad_(_DecoratorContextManager):
"""
:api_attr: imperative
@@ -389,34 +496,60 @@ def test_layer():
test_layer()
"""

-    def __call__(self, func):
-        @decorator.decorator
-        def _decorate_function(func, *args, **kwargs):
-            with self:
-                return func(*args, **kwargs)
+    def __enter__(self):
+        self.prev = is_grad_enabled()
+        _set_grad_enabled(False)

-        @decorator.decorator
-        def _decorate_generator(func, *args, **kwargs):
-            gen = func(*args, **kwargs)
-            with self:
-                for x in gen:
-                    yield x
+    def __exit__(self, *args):
+        _set_grad_enabled(self.prev)

-        if inspect.isgeneratorfunction(func):
-            return _decorate_generator(func)
-        else:
-            return _decorate_function(func)


+class enable_grad(_DecoratorContextManager):
+    """
+    :api_attr: imperative
+
+    Create a context which enables dygraph gradient calculation,
+    if it has been disabled by `no_grad` or `set_grad_enabled`.
+
+    In this mode, the result of every computation will have `stop_gradient` set
+    to `False`.
+
+    Also functions as a decorator. (Make sure to use an instance.)
+
+    Examples:
+     .. code-block:: python
+
+        import paddle
+
+        # use as context manager
+        x = paddle.to_tensor([1.], stop_gradient=False)
+        with paddle.no_grad():
+            with paddle.enable_grad():
+                y = x * 2
+        assert(y.stop_gradient == False)
+        y.backward()
+        assert(x.grad is not None)
+
+        # use as decorator
+        @paddle.enable_grad()
+        def double(x):
+            return x * 2
+
+        with paddle.no_grad():
+            z = double(x)
+
+        assert(z.stop_gradient == False)
+    """

    def __enter__(self):
-        tracer = framework._dygraph_tracer()
-        if tracer:
-            self.orig = tracer._has_grad
-            tracer._has_grad = False
+        self.prev = is_grad_enabled()
+        _set_grad_enabled(True)

    def __exit__(self, *args):
-        tracer = framework._dygraph_tracer()
-        if tracer:
-            tracer._has_grad = self.orig
+        _set_grad_enabled(self.prev)


@signature_safe_contextmanager
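
The refactor in base.py hinges on the _DecoratorContextManager base class added above: any subclass that implements __enter__/__exit__ can be used as a with-block, as a function decorator, or as a decorator on a generator function. A minimal self-contained sketch of that pattern (plain Python using functools.wraps; the class name here is made up, and the Paddle code uses the third-party decorator package instead so that inspect.getfullargspec of the wrapped function matches the original, which the unit tests below check):

    import functools
    import inspect


    class ContextDecorator:
        """Context manager that also works as a (generator-aware) decorator."""

        def __call__(self, func):
            if inspect.isgeneratorfunction(func):

                @functools.wraps(func)
                def gen_wrapper(*args, **kwargs):
                    gen = func(*args, **kwargs)
                    with self:  # enter once, hold the context for the whole iteration
                        yield from gen

                return gen_wrapper

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return wrapper

        def __enter__(self):
            raise NotImplementedError

        def __exit__(self, exc_type, exc_value, traceback):
            raise NotImplementedError

The standard library's contextlib.ContextDecorator covers the plain-function case but not generator functions, which is presumably why both the old no_grad_ and the new base class implement the dispatch themselves.
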
147 changes: 147 additions & 0 deletions python/paddle/fluid/tests/unittests/test_imperative_decorator.py
@@ -90,6 +90,7 @@ def test_main(self):

self.tracer = framework._dygraph_tracer()
self.tracer._train_mode = True
self.tracer._has_grad = True

self.assertEqual(self.no_grad_func(1), 1)
self.assertEqual(self.no_grad_func.__name__, "no_grad_func")
@@ -123,5 +124,151 @@ def test_wrapped_gen():
self.assertEqual(a, b)


class TestEnableGradClass(unittest.TestCase):
@paddle.enable_grad()
def enable_grad_func(self, a):
self.assertEqual(self.tracer._train_mode, True)
self.assertEqual(self.tracer._has_grad, True)
return a

def test_main(self):
paddle.disable_static()

self.tracer = framework._dygraph_tracer()
self.tracer._train_mode = True
self.tracer._has_grad = False

self.assertEqual(self.enable_grad_func(1), 1)
self.assertEqual(self.enable_grad_func.__name__, "enable_grad_func")

def need_enable_grad_func(a, b=1):
return a + b

decorated_func = paddle.enable_grad()(need_enable_grad_func)
self.assertEqual(
str(inspect.getfullargspec(decorated_func)),
str(inspect.getfullargspec(need_enable_grad_func)),
)

def test_gen():
for i in range(3):
yield i

a = 0
for i in test_gen():
a += i

@paddle.enable_grad()
def test_wrapped_gen():
for i in range(3):
yield i

b = 0
for i in test_wrapped_gen():
b += i

self.assertEqual(a, b)

def test_stop_gradient(self):
x = paddle.to_tensor([1.0], stop_gradient=False)
with paddle.no_grad():
with paddle.enable_grad():
y = x * 2
self.assertTrue(y.stop_gradient is False)
y.backward()
self.assertTrue(x.grad is not None)

# use as decorator
@paddle.enable_grad()
def double(x):
return x * 2

with paddle.no_grad():
z = double(x)

self.assertTrue(z.stop_gradient is False)


class TestSetGradEnabledClass(unittest.TestCase):
@paddle.set_grad_enabled(True)
def enable_grad_func(self, a):
self.assertEqual(self.tracer._train_mode, True)
self.assertEqual(self.tracer._has_grad, True)
return a

def test_main(self):
paddle.disable_static()

self.tracer = framework._dygraph_tracer()
self.tracer._train_mode = True

self.assertEqual(self.enable_grad_func(1), 1)
self.assertEqual(self.enable_grad_func.__name__, "enable_grad_func")

def need_enable_grad_func(a, b=1):
return a + b

decorated_func = paddle.set_grad_enabled(True)(need_enable_grad_func)
self.assertEqual(
str(inspect.getfullargspec(decorated_func)),
str(inspect.getfullargspec(need_enable_grad_func)),
)

def test_gen():
for i in range(3):
yield i

a = 0
for i in test_gen():
a += i

@paddle.set_grad_enabled(True)
def test_wrapped_gen():
for i in range(3):
yield i

b = 0
for i in test_wrapped_gen():
b += i

self.assertEqual(a, b)

def test_stop_gradient(self):
x = paddle.to_tensor([1.0], stop_gradient=False)
is_train = False
with paddle.set_grad_enabled(is_train):
y = x * 2
self.assertTrue(y.stop_gradient is True)

paddle.set_grad_enabled(True)
y = x * 2
self.assertTrue(y.stop_gradient is False)

paddle.set_grad_enabled(False)
y = x * 2
self.assertTrue(y.stop_gradient is True)


class TestIsGradEnabledClass(unittest.TestCase):
def test_main(self):
paddle.disable_static()

self.tracer = framework._dygraph_tracer()
self.tracer._train_mode = True
self.tracer._has_grad = True

# Dygraph gradient calculation mode is enabled by default.
flag = paddle.is_grad_enabled()
self.assertTrue(flag is True)

with paddle.set_grad_enabled(False):
flag = paddle.is_grad_enabled()
self.assertTrue(flag is False)

flag = paddle.is_grad_enabled()
self.assertTrue(flag is True)
paddle.enable_static()


if __name__ == '__main__':
unittest.main()
2 changes: 0 additions & 2 deletions python/paddle/framework/__init__.py
@@ -18,8 +18,6 @@
from .random import seed # noqa: F401
from .framework import get_default_dtype # noqa: F401
from .framework import set_default_dtype # noqa: F401
-from .framework import set_grad_enabled # noqa: F401
-from .framework import is_grad_enabled # noqa: F401

from ..fluid.param_attr import ParamAttr # noqa: F401
from ..fluid.core import CPUPlace # noqa: F401
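
With these re-exports gone, the grad-mode helpers live under paddle.autograd (and remain re-exported at the paddle top level via python/paddle/__init__.py), so imports along the lines of the sketch below should keep working, while paddle.framework.set_grad_enabled / paddle.framework.is_grad_enabled presumably no longer resolve:

    import paddle
    from paddle.autograd import is_grad_enabled, set_grad_enabled

    # Top-level spellings are unchanged for users.
    paddle.set_grad_enabled(True)
    print(paddle.is_grad_enabled())
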