From 468ace3a5c5b16885c731a5fec058dcdc96b4663 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Wed, 1 Sep 2021 22:14:27 +0800 Subject: [PATCH 1/6] add pool2d grad grad --- paddle/fluid/operators/pool_cudnn_op.cu.cc | 17 ++++++++++++++ paddle/fluid/operators/pool_op.cc | 24 ++++++++++++++++++- paddle/fluid/operators/pool_op.cu | 7 ++++++ paddle/fluid/operators/pool_op.h | 27 ++++++++++++++++++++++ 4 files changed, 74 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc index 1bdb3728f538e..164c271ec2b22 100644 --- a/paddle/fluid/operators/pool_cudnn_op.cu.cc +++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc @@ -505,6 +505,19 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel { } }; +template +class PoolCUDNNGradGradOpKernel : public PoolCUDNNOpKernel{ + public: + void Compute(const framework::ExecutionContext &ctx) const override{ + std::string pooling_type = context.Attr("pooling_type"); + if (pooling_type == "max") { + PADDLE_THROW(platform::errors::InvalidArgument( + "Pool op grad grad only supports avgpool.")); + } + else PoolCUDNNOpKernel::Compute(ctx); + } +}; + } // namespace operators } // namespace paddle @@ -534,6 +547,10 @@ REGISTER_OP_KERNEL(pool2d_grad, CUDNN, plat::CUDAPlace, ops::PoolCUDNNGradOpKernel, ops::PoolCUDNNGradOpKernel, ops::PoolCUDNNGradOpKernel); +REGISTER_OP_KERNEL(pool2d_grad_grad, CUDNN, plat::CUDAPlace, + ops::PoolCUDNNGradGradOpKernel, + ops::PoolCUDNNGradGradOpKernel, + ops::PoolCUDNNGradGradOpKernel); REGISTER_OP_KERNEL(pool3d, CUDNN, plat::CUDAPlace, ops::PoolCUDNNOpKernel, diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index feb47a73ee405..f335117f9b1a5 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -464,6 +464,22 @@ The input(X) size and output(Out) size may be different. )DOC"); } + +template +class Pool2dOpGradGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + grad_op->SetType("pool2d_grad_grad"); + grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X"))); + grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out"))); + grad_op->SetAttrMap(this->Attrs()); + } +}; + + class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { protected: std::unordered_map& GetInputOutputWithSameType() @@ -680,7 +696,9 @@ REGISTER_OPERATOR( pool2d, ops::PoolOp, ops::Pool2dOpMaker, ops::PoolOpInferVarType, paddle::framework::DefaultGradOpMaker, paddle::framework::DefaultGradOpMaker); -REGISTER_OPERATOR(pool2d_grad, ops::PoolOpGrad); +REGISTER_OPERATOR(pool2d_grad, ops::PoolOpGrad, + ops::Pool2dOpGradGradMaker, + ops::Pool2dOpGradGradMaker); REGISTER_OP_CPU_KERNEL( pool2d, ops::PoolKernel, @@ -688,6 +706,10 @@ REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL( pool2d_grad, ops::PoolGradKernel, ops::PoolGradKernel); +REGISTER_OP_CPU_KERNEL( + pool2d_grad_grad, + ops::PoolGradGradKernel, + ops::PoolGradGradKernel); REGISTER_OPERATOR( pool3d, ops::PoolOp, ops::Pool3dOpMaker, ops::PoolOpInferVarType, diff --git a/paddle/fluid/operators/pool_op.cu b/paddle/fluid/operators/pool_op.cu index 6b1e9f93033aa..0608a15ff1d14 100644 --- a/paddle/fluid/operators/pool_op.cu +++ b/paddle/fluid/operators/pool_op.cu @@ -28,6 +28,13 @@ REGISTER_OP_CUDA_KERNEL( ops::PoolGradKernel); +REGISTER_OP_CUDA_KERNEL( + pool2d_grad_grad, + ops::PoolGradGradKernel, + ops::PoolGradGradKernel, + ops::PoolGradGradKernel); + REGISTER_OP_CUDA_KERNEL( pool3d, ops::PoolKernel, ops::PoolKernel, diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h index e84c92d9a1624..d744ca710f0b8 100644 --- a/paddle/fluid/operators/pool_op.h +++ b/paddle/fluid/operators/pool_op.h @@ -73,6 +73,18 @@ class PoolOpGrad : public framework::OperatorWithKernel { const framework::OpKernelType& expected_kernel_type) const override; }; + +class PoolOpGradGrad : public PoolOp{ + public: + using framework::OperatorWithKernel::OperatorWithKernel; + using PoolOp::InferShape; + + protected: + using PoolOp::GetExpectedKernelType; + using PoolOp::GetKernelTypeForVar; +}; + + class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override; @@ -357,5 +369,20 @@ class PoolGradKernel : public framework::OpKernel { } }; + +template +class PoolGradGradKernel : public PoolKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + std::string pooling_type = context.Attr("pooling_type"); + if (pooling_type == "max") { + PADDLE_THROW(platform::errors::InvalidArgument( + "Pool op grad grad only supports avgpool.")); + } + else PoolKernel::Compute(context); + } +}; + + } // namespace operators } // namespace paddle From bb5a4659a64d8798b1e95a62c2ba2a8e417c417d Mon Sep 17 00:00:00 2001 From: weishengyu Date: Thu, 2 Sep 2021 11:10:27 +0800 Subject: [PATCH 2/6] dbg --- paddle/fluid/operators/pool_cudnn_op.cu.cc | 4 ++-- paddle/fluid/operators/pool_op.cc | 1 + paddle/fluid/operators/pool_op.h | 13 +------------ 3 files changed, 4 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc index 164c271ec2b22..079fb6c6fd8a9 100644 --- a/paddle/fluid/operators/pool_cudnn_op.cu.cc +++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc @@ -509,12 +509,12 @@ template class PoolCUDNNGradGradOpKernel : public PoolCUDNNOpKernel{ public: void Compute(const framework::ExecutionContext &ctx) const override{ - std::string pooling_type = context.Attr("pooling_type"); + std::string pooling_type = ctx.Attr("pooling_type"); if (pooling_type == "max") { PADDLE_THROW(platform::errors::InvalidArgument( "Pool op grad grad only supports avgpool.")); } - else PoolCUDNNOpKernel::Compute(ctx); + else PoolCUDNNOpKernel::Compute(ctx); } }; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index f335117f9b1a5..c0276cfceef55 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -699,6 +699,7 @@ REGISTER_OPERATOR( REGISTER_OPERATOR(pool2d_grad, ops::PoolOpGrad, ops::Pool2dOpGradGradMaker, ops::Pool2dOpGradGradMaker); +REGISTER_OPERATOR(pool2d_grad_grad, ops::PoolOp); REGISTER_OP_CPU_KERNEL( pool2d, ops::PoolKernel, diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h index d744ca710f0b8..e8bdf557d040d 100644 --- a/paddle/fluid/operators/pool_op.h +++ b/paddle/fluid/operators/pool_op.h @@ -74,17 +74,6 @@ class PoolOpGrad : public framework::OperatorWithKernel { }; -class PoolOpGradGrad : public PoolOp{ - public: - using framework::OperatorWithKernel::OperatorWithKernel; - using PoolOp::InferShape; - - protected: - using PoolOp::GetExpectedKernelType; - using PoolOp::GetKernelTypeForVar; -}; - - class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override; @@ -379,7 +368,7 @@ class PoolGradGradKernel : public PoolKernel { PADDLE_THROW(platform::errors::InvalidArgument( "Pool op grad grad only supports avgpool.")); } - else PoolKernel::Compute(context); + else PoolKernel::Compute(context); } }; From 2b407989b2e8c35a77f037def59bc7d40cdd0b35 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Thu, 2 Sep 2021 14:32:49 +0800 Subject: [PATCH 3/6] add unittest --- .../fluid/tests/unittests/test_nn_grad.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py index 33d313e709e92..7270bab59f236 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -381,5 +381,29 @@ def test_grad(self): self.func(p) +class TestAvgPool2DDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + input_NCHW = fluid.layers.data( + name="input_NCHW", + shape=[2, 3, 5, 5], + append_batch_size=False, + dtype="float32") + + input_NCHW.persistable = True + y = layers.pool2d(input_NCHW, pool_size=2, pool_type="avg") + x_arr = np.random.uniform(-1, 1, [2, 3, 5, 5]).astype(np.float32) + + gradient_checker.double_grad_check( + [input_NCHW], y, x_init=x_arr, place=place, eps=0.05) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + if __name__ == "__main__": unittest.main() From 88cdd4cc7fe59cfb1377220f01d842a52fa88c28 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Thu, 2 Sep 2021 07:13:49 +0000 Subject: [PATCH 4/6] update format --- paddle/fluid/operators/pool_cudnn_op.cu.cc | 7 ++++--- paddle/fluid/operators/pool_op.cc | 2 -- paddle/fluid/operators/pool_op.cu | 2 +- paddle/fluid/operators/pool_op.h | 6 ++---- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/pool_cudnn_op.cu.cc b/paddle/fluid/operators/pool_cudnn_op.cu.cc index 079fb6c6fd8a9..8fcd40a9a2df4 100644 --- a/paddle/fluid/operators/pool_cudnn_op.cu.cc +++ b/paddle/fluid/operators/pool_cudnn_op.cu.cc @@ -506,15 +506,16 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel { }; template -class PoolCUDNNGradGradOpKernel : public PoolCUDNNOpKernel{ +class PoolCUDNNGradGradOpKernel : public PoolCUDNNOpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override{ + void Compute(const framework::ExecutionContext &ctx) const override { std::string pooling_type = ctx.Attr("pooling_type"); if (pooling_type == "max") { PADDLE_THROW(platform::errors::InvalidArgument( "Pool op grad grad only supports avgpool.")); + } else { + PoolCUDNNOpKernel::Compute(ctx); } - else PoolCUDNNOpKernel::Compute(ctx); } }; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index c0276cfceef55..e8e6125335a9f 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -464,7 +464,6 @@ The input(X) size and output(Out) size may be different. )DOC"); } - template class Pool2dOpGradGradMaker : public framework::SingleGradOpMaker { public: @@ -479,7 +478,6 @@ class Pool2dOpGradGradMaker : public framework::SingleGradOpMaker { } }; - class PoolOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { protected: std::unordered_map& GetInputOutputWithSameType() diff --git a/paddle/fluid/operators/pool_op.cu b/paddle/fluid/operators/pool_op.cu index 0608a15ff1d14..069ce0c1fda85 100644 --- a/paddle/fluid/operators/pool_op.cu +++ b/paddle/fluid/operators/pool_op.cu @@ -33,7 +33,7 @@ REGISTER_OP_CUDA_KERNEL( ops::PoolGradGradKernel, ops::PoolGradGradKernel, ops::PoolGradGradKernel); + paddle::platform::float16>); REGISTER_OP_CUDA_KERNEL( pool3d, ops::PoolKernel, diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h index e8bdf557d040d..9ee8eab1a7922 100644 --- a/paddle/fluid/operators/pool_op.h +++ b/paddle/fluid/operators/pool_op.h @@ -73,7 +73,6 @@ class PoolOpGrad : public framework::OperatorWithKernel { const framework::OpKernelType& expected_kernel_type) const override; }; - class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override; @@ -358,7 +357,6 @@ class PoolGradKernel : public framework::OpKernel { } }; - template class PoolGradGradKernel : public PoolKernel { public: @@ -367,11 +365,11 @@ class PoolGradGradKernel : public PoolKernel { if (pooling_type == "max") { PADDLE_THROW(platform::errors::InvalidArgument( "Pool op grad grad only supports avgpool.")); + } else { + PoolKernel::Compute(context); } - else PoolKernel::Compute(context); } }; - } // namespace operators } // namespace paddle From ea34038abe2ba0a7a4e53ee997e7329c1533da08 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Thu, 2 Sep 2021 12:58:08 +0000 Subject: [PATCH 5/6] add more unittests --- .../fluid/tests/unittests/test_nn_grad.py | 79 +++++++++++++------ 1 file changed, 56 insertions(+), 23 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py index 7270bab59f236..c9950386383ec 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -340,38 +340,46 @@ def test_grad(self): self.func(p) -class TestConstantPadDoubleGradCheckCase1(TestConstantPadDoubleGradCheck): +class TestAvgPool2DDoubleGradCheckCase1(unittest.TestCase): @prog_scope() def func(self, place): - x_shape = [2, 3, 4, 5] - pad = [1, 0, 1, 0, 1, 0, 1, 0] - dtype = np.float64 + input_NCHW = fluid.layers.data( + name="input_NCHW", + shape=[2, 3, 5, 5], + append_batch_size=False, + dtype="float32") - x = layers.data('x', x_shape, False, dtype) - x.persistable = True - out = paddle.nn.functional.pad(x, pad) - x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + input_NCHW.persistable = True + y = layers.pool2d(input_NCHW, pool_size=2, pool_type="avg") + x_arr = np.random.uniform(-1, 1, [2, 3, 5, 5]).astype(np.float32) - gradient_checker.double_grad_check([x], out, x_init=x_arr, place=place) + gradient_checker.double_grad_check( + [input_NCHW], y, x_init=x_arr, place=place, eps=0.05) + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) -class TestConcatDoubleGradCheck(unittest.TestCase): + +class TestAvgPool2DDoubleGradCheckCase2(unittest.TestCase): @prog_scope() def func(self, place): - x_shape = [2, 3, 4, 5] - pad = [1, 1, 1, 1] - dtype = np.float64 + input_NHWC = fluid.layers.data( + name="input_NHWC", + shape=[2, 5, 5, 3], + append_batch_size=False, + dtype="float32") - x1 = layers.data('x', x_shape, False, dtype) - x2 = layers.data('x', x_shape, False, dtype) - x1.persistable = True - x2.persistable = True - out = paddle.concat([x1, x2], axis=0) - x2_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) - x1_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + input_NHWC.persistable = True + y = layers.pool2d( + input_NHWC, pool_size=2, pool_type="avg", data_format="NHWC") + x_arr = np.random.uniform(-1, 1, [2, 5, 5, 3]).astype(np.float32) gradient_checker.double_grad_check( - [x1, x2], out, x_init=[x1_arr, x2_arr], place=place) + [input_NHWC], y, x_init=x_arr, place=place, eps=0.05) def test_grad(self): places = [fluid.CPUPlace()] @@ -381,7 +389,7 @@ def test_grad(self): self.func(p) -class TestAvgPool2DDoubleGradCheck(unittest.TestCase): +class TestAvgPool2DDoubleGradCheckCase3(unittest.TestCase): @prog_scope() def func(self, place): input_NCHW = fluid.layers.data( @@ -391,7 +399,32 @@ def func(self, place): dtype="float32") input_NCHW.persistable = True - y = layers.pool2d(input_NCHW, pool_size=2, pool_type="avg") + y = layers.pool2d( + input_NCHW, pool_size=2, pool_type="avg", pool_padding=[1, 1]) + x_arr = np.random.uniform(-1, 1, [2, 3, 5, 5]).astype(np.float32) + + gradient_checker.double_grad_check( + [input_NCHW], y, x_init=x_arr, place=place, eps=0.05) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestAvgPool2DDoubleGradCheckCase4(unittest.TestCase): + @prog_scope() + def func(self, place): + input_NCHW = fluid.layers.data( + name="input_NCHW", + shape=[2, 3, 5, 5], + append_batch_size=False, + dtype="float32") + + input_NCHW.persistable = True + y = layers.pool2d(input_NCHW, pool_size=[4, 4], pool_type="avg") x_arr = np.random.uniform(-1, 1, [2, 3, 5, 5]).astype(np.float32) gradient_checker.double_grad_check( From 782ef3f4e78bd216ba8db3c58e96750ab15c6c47 Mon Sep 17 00:00:00 2001 From: weishengyu Date: Thu, 2 Sep 2021 13:00:53 +0000 Subject: [PATCH 6/6] dbg --- .../fluid/tests/unittests/test_nn_grad.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py index c9950386383ec..722926b0d77f7 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -340,6 +340,47 @@ def test_grad(self): self.func(p) +class TestConstantPadDoubleGradCheckCase1(TestConstantPadDoubleGradCheck): + @prog_scope() + def func(self, place): + x_shape = [2, 3, 4, 5] + pad = [1, 0, 1, 0, 1, 0, 1, 0] + dtype = np.float64 + + x = layers.data('x', x_shape, False, dtype) + x.persistable = True + out = paddle.nn.functional.pad(x, pad) + x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + + gradient_checker.double_grad_check([x], out, x_init=x_arr, place=place) + + +class TestConcatDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + x_shape = [2, 3, 4, 5] + pad = [1, 1, 1, 1] + dtype = np.float64 + + x1 = layers.data('x', x_shape, False, dtype) + x2 = layers.data('x', x_shape, False, dtype) + x1.persistable = True + x2.persistable = True + out = paddle.concat([x1, x2], axis=0) + x2_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + x1_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + + gradient_checker.double_grad_check( + [x1, x2], out, x_init=[x1_arr, x2_arr], place=place) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + class TestAvgPool2DDoubleGradCheckCase1(unittest.TestCase): @prog_scope() def func(self, place):