
Commit

[Custom Extension] Fix custom double_grad backward=None (PaddlePaddle#49224)

* fix custom double_grad backward=None

* fix custom_relu.cu bug && polish testcase of double_grad

* remove old dynamic graph test
jiahy0825 committed Dec 23, 2022
1 parent 9fafa06 commit 87e2210
Showing 3 changed files with 149 additions and 102 deletions.
20 changes: 11 additions & 9 deletions paddle/fluid/eager/custom_operator/custom_operator_node.cc
@@ -410,17 +410,19 @@ RunCustomOpDoubleGradNode::operator()(

   for (size_t i = 0; i < OutputMeta().size(); i++) {
     if (map[1][0].find(i) != map[1][0].end()) {
+      int grad_output_idx = map[1][0][i];
       VLOG(7) << "Insert grad outputs: " << i
-              << " with size: " << OutputMeta()[i].size()
-              << " to tmp_outputs: " << map[1][0][i];
-      for (size_t j = 0; j < OutputMeta()[i].size(); j++) {
-        outs[i].emplace_back(/* init it incase of copy nullptr of shared_ptr */
-                             std::make_shared<phi::DenseTensor>(
-                                 phi::DataType::UNDEFINED),
-                             egr::Controller::Instance().GenerateUniqueName(
-                                 "custom_tmp_grad"));
+              << " with size: " << OutputMeta()[grad_output_idx].size()
+              << " to tmp_outputs: " << grad_output_idx;
+      for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) {
+        outs[grad_output_idx]
+            .emplace_back(/* init it incase of copy nullptr of shared_ptr */
+                          std::make_shared<phi::DenseTensor>(
+                              phi::DataType::UNDEFINED),
+                          egr::Controller::Instance().GenerateUniqueName(
+                              "custom_tmp_grad"));
       }
-      tmp_outs[map[1][0][i]] = outs[i];
+      tmp_outs[grad_output_idx] = outs[grad_output_idx];
     }
   }
   for (size_t i = 0; i < tmp_outs.size(); i++) {
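The crux of the fix: map[1][0] translates a forward-side output index i into the slot the grad outputs actually occupy, and the old code used the unmapped i for sizing and storage while handing the result off under the mapped slot. When the mapping is not the identity, the two indexings disagree, which is presumably why it only surfaced in the backward=None case the commit title names. A minimal standalone sketch of the indexing pattern (the GradSlotMap alias and the vector-of-vectors stand-ins are illustrative, not Paddle's types):

#include <cstdio>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for Paddle's forward-index -> grad-slot mapping.
using GradSlotMap = std::unordered_map<size_t, size_t>;

int main() {
  GradSlotMap map = {{0, 2}, {1, 0}};  // non-identity mapping
  std::vector<std::vector<float>> outs(3);
  std::vector<std::vector<float>> tmp_outs(3);

  for (size_t i = 0; i < 2; i++) {
    // The fix in a nutshell: resolve the mapped slot once and use it for
    // BOTH the allocation and the hand-off, so they cannot disagree.
    size_t grad_output_idx = map[i];
    outs[grad_output_idx].push_back(1.0f);              // fill the mapped slot
    tmp_outs[grad_output_idx] = outs[grad_output_idx];  // hand off the same slot
    // The old code filled outs[i] but assigned tmp_outs[map[i]] = outs[i];
    // with a non-identity map, outs then holds buffers at the wrong slots
    // for anything that later reads it by grad-output index.
  }
  std::printf("tmp_outs[2]: %zu element(s)\n", tmp_outs[2].size());
  return 0;
}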
2 changes: 1 addition & 1 deletion python/paddle/fluid/tests/custom_op/custom_relu_op.cu
@@ -44,7 +44,7 @@ __global__ void relu_cuda_double_backward_kernel(const data_t* out_data,
                                                  data_t* ddout_data,
                                                  int64_t num) {
   int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
-  for (int64_t i = num; i < num; i += blockDim.x * gridDim.x) {
+  for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
     ddout_data[i] = ddx_data[i] * (out_data[i] > static_cast<data_t>(0.)
                                        ? static_cast<data_t>(1.)
                                        : static_cast<data_t>(0.));
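The one-character change above is the whole bug: a grid-stride loop must start each thread at its global index gid, but starting at num makes the condition i < num false on the very first test, so the kernel silently wrote nothing and ddout was left uninitialized. A minimal self-contained sketch of the canonical grid-stride pattern (plain float instead of the templated data_t; kernel and variable names here are illustrative):

#include <cstdio>
#include <cuda_runtime.h>

// Grid-stride loop: each thread starts at its global index and jumps by
// the total thread count, so any grid size covers all `num` elements.
// Starting at `num` instead of `gid` makes the loop a silent no-op.
__global__ void relu_double_backward(const float* out,
                                     const float* ddx,
                                     float* ddout,
                                     int64_t num) {
  int64_t gid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) {
    ddout[i] = ddx[i] * (out[i] > 0.f ? 1.f : 0.f);
  }
}

int main() {
  const int64_t num = 1 << 20;
  float *out, *ddx, *ddout;
  cudaMallocManaged(&out, num * sizeof(float));
  cudaMallocManaged(&ddx, num * sizeof(float));
  cudaMallocManaged(&ddout, num * sizeof(float));
  for (int64_t i = 0; i < num; i++) {
    out[i] = (i % 2) ? 1.f : -1.f;  // mix of positive and negative outputs
    ddx[i] = 2.f;
  }

  const int block = 256;
  const int grid = static_cast<int>((num + block - 1) / block);
  relu_double_backward<<<grid, block>>>(out, ddx, ddout, num);
  cudaDeviceSynchronize();
  std::printf("ddout[1] = %f (expect 2.0)\n", ddout[1]);
  cudaFree(out); cudaFree(ddx); cudaFree(ddout);
  return 0;
}

Because every thread advances by blockDim.x * gridDim.x, the loop stays correct for any launch configuration, including grids with fewer threads than elements.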
