add _reset_grad_inplace_version in eager mode #41101

Merged
24 changes: 24 additions & 0 deletions paddle/fluid/pybind/eager_method.cc
@@ -1199,6 +1199,27 @@ static PyObject* tensor_method_get_rows(TensorObject* self, PyObject* args,
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  Py_ssize_t args_num = PyTuple_Size(args);
  bool set_to_zero = true;
  if (args_num == (Py_ssize_t)1) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::experimental::Tensor* grad =
      egr::EagerUtils::mutable_grad(self->tensor);
  if (grad && grad->defined() && grad->is_dense_tensor() &&
      grad->initialized()) {
    grad->reset_inplace_version(set_to_zero);
  }
  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

PyMethodDef variable_methods[] = {
    {"numpy", (PyCFunction)(void (*)(void))tensor_method_numpy,
     METH_VARARGS | METH_KEYWORDS, NULL},
@@ -1292,6 +1313,9 @@ PyMethodDef variable_methods[] = {
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"rows", (PyCFunction)(void (*)(void))tensor_method_get_rows,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_reset_grad_inplace_version",
     (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {NULL, NULL, 0, NULL}};

} // namespace pybind
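Taken together, the binding parses at most one optional boolean argument (defaulting to true when omitted) and quietly does nothing unless the tensor already holds an initialized dense gradient. A minimal usage sketch from the Python side, assuming eager mode is active (the tensor setup below is illustrative, not taken from this PR):

import paddle

paddle.set_device('cpu')
w = paddle.to_tensor([[1.0]], 'float32', stop_gradient=False)
w._reset_grad_inplace_version(True)  # safe no-op: w.grad does not exist yet

out = w * 0.1
out.backward()                       # w.grad is now an initialized dense tensor

w._reset_grad_inplace_version()      # equivalent: set_to_zero defaults to true
w._reset_grad_inplace_version(True)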
5 changes: 5 additions & 0 deletions paddle/phi/api/include/tensor.h
@@ -516,6 +516,11 @@ class PADDLE_API Tensor final {
   */
  uint32_t current_inplace_version();

  /**
   * @brief Reset inplace version
   */
  void reset_inplace_version(bool set_to_zero = false);

  /* Part 10: Auto generated Tensor methods */

  /* Part 11: Methods of converting SparseTensor and DenseTensor to each other
11 changes: 11 additions & 0 deletions paddle/phi/api/lib/tensor.cc
@@ -390,5 +390,16 @@ uint32_t Tensor::current_inplace_version() {
  return 0;
}

void Tensor::reset_inplace_version(bool set_to_zero) {
  if (set_to_zero) {
    if (is_dense_tensor()) {
      auto &inplace_version_counter =
          std::dynamic_pointer_cast<phi::DenseTensor>(impl_)
              ->InplaceVersionCounter();
      inplace_version_counter.SetInplaceVersionToZero();
    }
  }
}

} // namespace experimental
} // namespace paddle
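Note the asymmetry in this implementation: the counter is only touched when set_to_zero is true; passing false is a no-op on this eager code path. A hedged sketch of the observable effect through the version counter (the exact counts depend on how the gradient was produced, so the printed values are expectations rather than guarantees):

import paddle

paddle.set_device('cpu')
w = paddle.to_tensor([[1.0]], 'float32', stop_gradient=False)
(w * 0.1).backward()

w.grad.add_(paddle.ones([1, 1]))     # an in-place op bumps the grad's version counter
print(w.grad._inplace_version())     # expected: previous value + 1

w._reset_grad_inplace_version(True)  # calls SetInplaceVersionToZero() underneath
print(w.grad._inplace_version())     # expected: 0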
@@ -16,6 +16,7 @@
import paddle.fluid as fluid
from paddle import _C_ops
from paddle.fluid import framework
+from paddle.fluid.framework import _test_eager_guard
import unittest
paddle.set_device('cpu')
@@ -32,7 +33,7 @@ def warp(*_):


class TestInplaceAndClearGradient(unittest.TestCase):
-    def test(self):
+    def func_test(self):
        input_data = np.ones([1, 1])
        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
@@ -45,6 +46,11 @@ def test(self):
        out.backward()
        assert w.grad[0] == 0.15

+    def test(self):
+        with _test_eager_guard():
+            self.func_test()
+        self.func_test()

# Test 2
class Counter:
@@ -67,7 +73,7 @@ def warp(*_):


class TestInplaceClearGradAccumulation(unittest.TestCase):
-    def test(self):
+    def func_test(self):
        input_data = np.ones([1, 1])
        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
        c = Counter()
@@ -87,9 +93,14 @@ def test(self):
        assert c.num_calls == 1
        c.num_calls = 0

+    def test(self):
+        with _test_eager_guard():
+            self.func_test()
+        self.func_test()


class TestInplaceClearGradAccumulationAlt(unittest.TestCase):
-    def test(self):
+    def func_test(self):
        input_data = np.ones([1, 1])
        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
        out = _C_ops.scale(w, 'scale', 0.1)
@@ -100,6 +111,11 @@ def test(self):

        assert w.grad._inplace_version() == 1

+    def test(self):
+        with _test_eager_guard():
+            self.func_test()
+        self.func_test()


if __name__ == '__main__':
    unittest.main()
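All three test classes are converted with the same mechanical pattern Paddle used throughout its eager-mode migration: the original body becomes func_test, and a new test entry point runs it twice so both execution modes are covered. Schematically (the class name below is illustrative, not from this PR):

import unittest
from paddle.fluid.framework import _test_eager_guard

class ExampleMigratedTest(unittest.TestCase):
    def func_test(self):
        pass  # the original assertions go here

    def test(self):
        with _test_eager_guard():  # first pass: eager (final-state) mode
            self.func_test()
        self.func_test()           # second pass: legacy dygraph mode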