Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Custom Extension] Fix custom double_grad backward=None #49224

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions paddle/fluid/eager/custom_operator/custom_operator_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,17 +410,19 @@ RunCustomOpDoubleGradNode::operator()(

for (size_t i = 0; i < OutputMeta().size(); i++) {
if (map[1][0].find(i) != map[1][0].end()) {
int grad_output_idx = map[1][0][i];
VLOG(7) << "Insert grad outputs: " << i
<< " with size: " << OutputMeta()[i].size()
<< " to tmp_outputs: " << map[1][0][i];
for (size_t j = 0; j < OutputMeta()[i].size(); j++) {
outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */
std::make_shared<phi::DenseTensor>(
phi::DataType::UNDEFINED),
egr::Controller::Instance().GenerateUniqueName(
"custom_tmp_grad"));
<< " with size: " << OutputMeta()[grad_output_idx].size()
<< " to tmp_outputs: " << grad_output_idx;
for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) {
outs[grad_output_idx]
.emplace_back(/* init it incase of copy nullptr of shared_ptr */
std::make_shared<phi::DenseTensor>(
phi::DataType::UNDEFINED),
egr::Controller::Instance().GenerateUniqueName(
"custom_tmp_grad"));
}
tmp_outs[map[1][0][i]] = outs[i];
tmp_outs[grad_output_idx] = outs[grad_output_idx];
}
}
for (size_t i = 0; i < tmp_outs.size(); i++) {
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/custom_op/custom_relu_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ __global__ void relu_cuda_double_backward_kernel(const data_t* out_data,
data_t* ddout_data,
int64_t num) {
int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
for (int64_t i = num; i < num; i += blockDim.x * gridDim.x) {
for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
ddout_data[i] = ddx_data[i] * (out_data[i] > static_cast<data_t>(0.)
? static_cast<data_t>(1.)
: static_cast<data_t>(0.));
Expand Down
21 changes: 14 additions & 7 deletions python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,23 @@ def custom_relu_double_grad_dynamic(func, device, dtype, np_x, use_func=True):
t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)

out = func(t) if use_func else paddle.nn.functional.relu(t)
out.stop_gradient = False

dx = paddle.grad(
outputs=[out], inputs=[t], create_graph=True, retain_graph=True
outputs=out,
inputs=t,
grad_outputs=paddle.ones_like(t),
create_graph=True,
retain_graph=True,
)

dx[0].backward()
ddout = paddle.grad(
outputs=dx[0],
inputs=out.grad,
grad_outputs=paddle.ones_like(t),
create_graph=False,
)

assert dx[0].grad is not None
return dx[0].numpy(), dx[0].grad.numpy()
assert ddout[0].numpy() is not None
return dx[0].numpy(), ddout[0].numpy()


class TestNewCustomOpSetUpInstall(unittest.TestCase):
Expand Down Expand Up @@ -346,7 +353,7 @@ def test_static_save_and_run_inference_predictor(self):
)
paddle.disable_static()

def test_func_double_grad_dynamic(self):
def test_double_grad_dynamic(self):
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
for device in self.devices:
for dtype in self.dtypes:
Expand Down