Skip to content

Commit

Permalink
unify inplace_version checking log in new and old dygraph framework (#…
Browse files Browse the repository at this point in the history
…41209)

* change inplace_version checking log

* fix
  • Loading branch information
pangyoki committed Apr 1, 2022
1 parent c86e3a1 commit 93cb235
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 47 deletions.
20 changes: 10 additions & 10 deletions paddle/fluid/eager/tensor_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,25 +121,25 @@ class TensorWrapper {
static_cast<phi::DenseTensor*>(intermidiate_tensor_.impl().get());
auto& inplace_version_counter = dense_tensor->InplaceVersionCounter();

uint32_t current_inplace_version =
inplace_version_counter.CurrentVersion();
uint32_t wrapper_version_snapshot = inplace_version_snapshot_;
uint32_t tensor_version = inplace_version_counter.CurrentVersion();
PADDLE_ENFORCE_EQ(
current_inplace_version, inplace_version_snapshot_,
tensor_version, wrapper_version_snapshot,
paddle::platform::errors::PermissionDenied(
"Tensor '%s' used in gradient computation has been "
"modified by an inplace operation. "
"Its version is %d but the expected version is %d. "
"Please fix your code to void calling an inplace operator "
"after using the Tensor which will used in gradient "
"computation.",
intermidiate_tensor_.name(), current_inplace_version,
inplace_version_snapshot_));
VLOG(6) << " The inplace_version_snapshot_ of Tensor '"
intermidiate_tensor_.name(), tensor_version,
wrapper_version_snapshot));
VLOG(6) << " The wrapper_version_snapshot of Tensor '"
<< intermidiate_tensor_.name() << "' is [ "
<< inplace_version_snapshot_ << " ]";
VLOG(6) << " The current_inplace_version of Tensor '"
<< intermidiate_tensor_.name() << "' is [ "
<< current_inplace_version << " ]";
<< wrapper_version_snapshot << " ]";
VLOG(6) << " The tensor_version of Tensor '"
<< intermidiate_tensor_.name() << "' is [ " << tensor_version
<< " ]";
}
}

Expand Down
34 changes: 10 additions & 24 deletions python/paddle/fluid/tests/unittests/test_inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,11 @@ def func_test_backward_error(self):
var_d = var_b**2

loss = paddle.nn.functional.relu(var_c + var_d)
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()

def test_backward_error(self):
with _test_eager_guard():
Expand Down Expand Up @@ -203,18 +196,11 @@ def func_test_backward_error(self):
self.inplace_api_processing(var_b)

loss = paddle.nn.functional.relu(var_c)
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()

def test_backward_error(self):
with _test_eager_guard():
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/unittests/test_pylayer_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def forward(self, data):
z = layer(data)
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
z.backward()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,11 @@ def func_test_backward_error(self):
view_var_b[0] = 2. # var_b is modified inplace

loss = paddle.nn.functional.relu(var_c)
if in_dygraph_mode():
with self.assertRaisesRegexp(
RuntimeError,
"received current_inplace_version:{} != inplace_version_snapshot_:{}".
format(1, 0)):
loss.backward()
else:
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()

def test_backward_error(self):
with _test_eager_guard():
Expand Down

0 comments on commit 93cb235

Please sign in to comment.