add _reset_grad_inplace_version (#41101)
pangyoki committed Mar 30, 2022
1 parent a5bfa79 commit cb8afc2
Showing 4 changed files with 59 additions and 3 deletions.
24 changes: 24 additions & 0 deletions paddle/fluid/pybind/eager_method.cc
@@ -1308,6 +1308,27 @@ static PyObject* tensor_method_get_rows(TensorObject* self, PyObject* args,
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  Py_ssize_t args_num = PyTuple_Size(args);
  bool set_to_zero = true;
  if (args_num == (Py_ssize_t)1) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::experimental::Tensor* grad =
      egr::EagerUtils::mutable_grad(self->tensor);
  if (grad && grad->defined() && grad->is_dense_tensor() &&
      grad->initialized()) {
    grad->reset_inplace_version(set_to_zero);
  }
  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

PyMethodDef variable_methods[] = {
    {"numpy", (PyCFunction)(void (*)(void))tensor_method_numpy,
     METH_VARARGS | METH_KEYWORDS, NULL},
@@ -1407,6 +1428,9 @@ PyMethodDef variable_methods[] = {
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"rows", (PyCFunction)(void (*)(void))tensor_method_get_rows,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_reset_grad_inplace_version",
     (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {NULL, NULL, 0, NULL}};

} // namespace pybind
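
For reference, the new binding is exposed on eager Tensors as _reset_grad_inplace_version. A minimal usage sketch (not part of this commit), assuming eager mode and a gradient that is an initialized dense tensor, which is what the guard above checks before calling reset_inplace_version:

import paddle
from paddle.fluid.framework import _test_eager_guard

paddle.set_device('cpu')

with _test_eager_guard():
    w = paddle.to_tensor([[1.0]], 'float32', stop_gradient=False)
    out = paddle.scale(w, 0.1)
    out.backward()

    # Called with no argument, the binding defaults set_to_zero to True,
    # so the gradient's inplace version counter is reset to zero.
    w._reset_grad_inplace_version()

    # An explicit False still reaches Tensor::reset_inplace_version() but
    # leaves the counter unchanged, since the C++ side only acts when
    # set_to_zero is true.
    w._reset_grad_inplace_version(False)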
5 changes: 5 additions & 0 deletions paddle/phi/api/include/tensor.h
@@ -516,6 +516,11 @@ class PADDLE_API Tensor final {
   */
  uint32_t current_inplace_version();

  /**
   * @brief Reset inplace version
   */
  void reset_inplace_version(bool set_to_zero = false);

  /* Part 10: Auto generated Tensor methods */

  /* Part 11: Methods of converting SparseTensor and DenseTensor to each other
11 changes: 11 additions & 0 deletions paddle/phi/api/lib/tensor.cc
@@ -384,5 +384,16 @@ uint32_t Tensor::current_inplace_version() {
  return 0;
}

void Tensor::reset_inplace_version(bool set_to_zero) {
  if (set_to_zero) {
    if (is_dense_tensor()) {
      auto &inplace_version_counter =
          std::dynamic_pointer_cast<phi::DenseTensor>(impl_)
              ->InplaceVersionCounter();
      inplace_version_counter.SetInplaceVersionToZero();
    }
  }
}

} // namespace experimental
} // namespace paddle
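
The counter cleared by SetInplaceVersionToZero() is the same one the unit tests below read back through _inplace_version(). A short round-trip sketch (not from this commit); it assumes that an in-place write to the gradient increments the counter by one, with paddle.assign(..., output=...) chosen here as the in-place write:

import paddle
from paddle.fluid.framework import _test_eager_guard

paddle.set_device('cpu')

with _test_eager_guard():
    w = paddle.to_tensor([[1.0]], 'float32', stop_gradient=False)
    out = paddle.scale(w, 0.1)
    out.backward()

    # An in-place write to the gradient bumps its inplace version counter.
    paddle.assign(paddle.ones_like(w.grad), output=w.grad)
    assert w.grad._inplace_version() == 1

    # Resetting with set_to_zero=True puts the counter back to zero.
    w._reset_grad_inplace_version(True)
    assert w.grad._inplace_version() == 0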
19 changes: 19 additions & 3 deletions
@@ -16,6 +16,7 @@
import paddle.fluid as fluid
from paddle import _C_ops
from paddle.fluid import framework
from paddle.fluid.framework import _test_eager_guard
import unittest
paddle.set_device('cpu')

@@ -32,7 +33,7 @@ def warp(*_):


class TestInplaceAndClearGradient(unittest.TestCase):
    def test(self):
    def func_test(self):
        input_data = np.ones([1, 1])
        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)

@@ -45,6 +46,11 @@ def test(self):
        out.backward()
        assert w.grad[0] == 0.15

    def test(self):
        with _test_eager_guard():
            self.func_test()
        self.func_test()


# Test 2
class Counter:
@@ -67,7 +73,7 @@ def warp(*_):


class TestInplaceClearGradAccumulation(unittest.TestCase):
    def test(self):
    def func_test(self):
        input_data = np.ones([1, 1])
        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
        c = Counter()
@@ -87,9 +93,14 @@ def test(self):
        assert c.num_calls == 1
        c.num_calls = 0

    def test(self):
        with _test_eager_guard():
            self.func_test()
        self.func_test()


class TestInplaceClearGradAccumulationAlt(unittest.TestCase):
    def test(self):
    def func_test(self):
        input_data = np.ones([1, 1])
        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
        out = _C_ops.scale(w, 'scale', 0.1)
@@ -100,6 +111,11 @@ def test(self):

        assert w.grad._inplace_version() == 1

    def test(self):
        with _test_eager_guard():
            self.func_test()
        self.func_test()


if __name__ == '__main__':
    unittest.main()
