Skip to content

Commit

Permalink
fix bugs of reshape double grad infermeta (#41459) (#41493)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanRisheng committed Apr 8, 2022
1 parent 57fe4fc commit f196b84
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 5 deletions.
9 changes: 5 additions & 4 deletions paddle/fluid/operators/reshape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -441,26 +441,27 @@ class ReshapeDoubleGradKernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
auto *dd_x = ctx.Input<framework::Tensor>("DDX");
auto *d_out = ctx.Input<framework::Tensor>("DOut");
auto *dd_out = ctx.Output<framework::Tensor>("DDOut");
dd_out->mutable_data(ctx.GetPlace(), dd_x->type());

if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
phi::ReshapeDoubleGradKernel(
static_cast<const phi::CPUContext &>(dev_ctx), *dd_x, dd_out);
static_cast<const phi::CPUContext &>(dev_ctx), *d_out, *dd_x, dd_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
phi::ReshapeDoubleGradKernel(
static_cast<const phi::GPUContext &>(dev_ctx), *dd_x, dd_out);
static_cast<const phi::GPUContext &>(dev_ctx), *d_out, *dd_x, dd_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
phi::ReshapeDoubleGradKernel(
static_cast<const phi::XPUContext &>(dev_ctx), *dd_x, dd_out);
static_cast<const phi::XPUContext &>(dev_ctx), *d_out, *dd_x, dd_out);
}
#endif
}
Expand Down Expand Up @@ -658,7 +659,7 @@ REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp,

DECLARE_INFER_SHAPE_FUNCTOR(reshape2_grad_grad,
Reshape2DoubleGradInferShapeFunctor,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
PD_INFER_META(phi::ReshapeDoubleGradInferMeta));

REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
ops::ReshapeDoubleGradInplaceInferer,
Expand Down
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,14 @@ void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx) {
dx->set_layout(out_grad.layout());
}

void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& x_grad_grad,
MetaTensor* out_grad_grad) {
if (out_grad_grad != nullptr) {
out_grad_grad->share_dims(out_grad);
}
}

void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ void PoolGradInferMeta(const MetaTensor& x,

void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx);

void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& x_grad_grad,
MetaTensor* out_grad_grad);

void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/reshape_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ void ReshapeGradKernel(const Context& dev_ctx,

template <typename Context>
void ReshapeDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x_grad_grad,
DenseTensor* out_grad_grad) {
ReshapeGradKernel(dev_ctx, x_grad_grad, out_grad_grad);
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/reshape_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ void ReshapeGradKernel(const Context& dev_ctx,

template <typename Context>
void ReshapeDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const DenseTensor& x_grad_grad,
DenseTensor* out_grad_grad);

Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/compat/reshape_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ KernelSignature ReshapeGradOpArgumentMapping(

KernelSignature ReshapeDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("reshape_double_grad", {"DDX"}, {}, {"DDOut"});
return KernelSignature("reshape_double_grad", {"DOut", "DDX"}, {}, {"DDOut"});
}

} // namespace phi
Expand Down

0 comments on commit f196b84

Please sign in to comment.