From 3679ab1978f01e87b473b1e94a6d990155995d7a Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 21 Feb 2022 02:52:55 +0000 Subject: [PATCH 1/3] migrate huber_loss into phi --- paddle/fluid/operators/huber_loss_op.cc | 10 +- paddle/fluid/operators/huber_loss_op.cu | 24 ---- paddle/fluid/operators/huber_loss_op.h | 123 ------------------ paddle/fluid/operators/huber_loss_op_npu.cc | 2 +- paddle/fluid/operators/huber_loss_op_xpu.cc | 3 +- .../kernels/cpu/huber_loss_grad_kernel.cc | 25 ++++ paddle/pten/kernels/cpu/huber_loss_kernel.cc | 21 +++ .../kernels/gpu/huber_loss_grad_kernel.cu | 25 ++++ paddle/pten/kernels/gpu/huber_loss_kernel.cu | 21 +++ paddle/pten/kernels/huber_loss_grad_kernel.h | 30 +++++ paddle/pten/kernels/huber_loss_kernel.h | 30 +++++ .../impl/huber_loss_grad_kernel_impl.h | 75 +++++++++++ .../kernels/impl/huber_loss_kernel_impl.h | 61 +++++++++ paddle/pten/ops/compat/huber_loss_sig.cc | 36 +++++ 14 files changed, 328 insertions(+), 158 deletions(-) delete mode 100644 paddle/fluid/operators/huber_loss_op.cu delete mode 100644 paddle/fluid/operators/huber_loss_op.h create mode 100644 paddle/pten/kernels/cpu/huber_loss_grad_kernel.cc create mode 100644 paddle/pten/kernels/cpu/huber_loss_kernel.cc create mode 100644 paddle/pten/kernels/gpu/huber_loss_grad_kernel.cu create mode 100644 paddle/pten/kernels/gpu/huber_loss_kernel.cu create mode 100644 paddle/pten/kernels/huber_loss_grad_kernel.h create mode 100644 paddle/pten/kernels/huber_loss_kernel.h create mode 100644 paddle/pten/kernels/impl/huber_loss_grad_kernel_impl.h create mode 100644 paddle/pten/kernels/impl/huber_loss_kernel_impl.h create mode 100644 paddle/pten/ops/compat/huber_loss_sig.cc diff --git a/paddle/fluid/operators/huber_loss_op.cc b/paddle/fluid/operators/huber_loss_op.cc index 85c686f96120a..a4b1f77d80947 100644 --- a/paddle/fluid/operators/huber_loss_op.cc +++ b/paddle/fluid/operators/huber_loss_op.cc @@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/huber_loss_op.h" #include #include #include +#include "paddle/fluid/framework/op_registry.h" + namespace paddle { namespace operators { @@ -143,10 +144,3 @@ REGISTER_OPERATOR(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker, ops::HuberLossGradOpMaker, ops::HuberLossGradOpMaker); REGISTER_OPERATOR(huber_loss_grad, ops::HuberLossGradOp); -REGISTER_OP_CPU_KERNEL( - huber_loss, ops::HuberLossKernel, - ops::HuberLossKernel); -REGISTER_OP_CPU_KERNEL( - huber_loss_grad, - ops::HuberLossGradKernel, - ops::HuberLossGradKernel); diff --git a/paddle/fluid/operators/huber_loss_op.cu b/paddle/fluid/operators/huber_loss_op.cu deleted file mode 100644 index 4ce6856a7eade..0000000000000 --- a/paddle/fluid/operators/huber_loss_op.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ -#include "paddle/fluid/operators/huber_loss_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - huber_loss, - ops::HuberLossKernel, - ops::HuberLossKernel); -REGISTER_OP_CUDA_KERNEL( - huber_loss_grad, - ops::HuberLossGradKernel, - ops::HuberLossGradKernel); diff --git a/paddle/fluid/operators/huber_loss_op.h b/paddle/fluid/operators/huber_loss_op.h deleted file mode 100644 index fbfed71e1ecd4..0000000000000 --- a/paddle/fluid/operators/huber_loss_op.h +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/pten/core/hostdevice.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -template -using EigenVector = framework::EigenVector; - -template -struct HuberLossForward { - HOSTDEVICE HuberLossForward(const T& delta) : delta(delta) {} - - HOSTDEVICE T operator()(const T& val) const { - T abs_val = std::abs(val); - if (abs_val <= delta) { - return static_cast(0.5) * val * val; - } else { - return delta * (abs_val - static_cast(0.5) * delta); - } - } - - T delta; -}; - -template -class HuberLossKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in0 = context.Input("X"); - auto* in1 = context.Input("Y"); - auto* out0 = context.Output("Residual"); - auto* out1 = context.Output("Out"); - auto delta = static_cast(context.Attr("delta")); - auto& place = - *context.template device_context().eigen_device(); - - auto x = EigenVector::Flatten(*in0); - auto y = EigenVector::Flatten(*in1); - out0->mutable_data(context.GetPlace()); - auto residual = EigenVector::Flatten(*out0); - residual.device(place) = y - x; - out1->mutable_data(context.GetPlace()); - auto loss = EigenVector::Flatten(*out1); - loss.device(place) = residual.unaryExpr(HuberLossForward(delta)); - } -}; - -template -struct HuberLossBackward { - HOSTDEVICE HuberLossBackward(const T& delta, T sign) - : sign(sign), delta(delta) {} - - HOSTDEVICE T operator()(const T& val) const { - T abs_val = std::abs(val); - if (abs_val <= delta) { - return sign * val; - } else { - if (val > 0) { - return sign * delta; - } else { - return -1 * sign * delta; - } - } - } - - T sign; - T delta; -}; - -template -class HuberLossGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in0 = context.Input("Residual"); - auto* in1 = context.Input(framework::GradVarName("Out")); - auto* out0 = context.Output(framework::GradVarName("X")); - auto* out1 = context.Output(framework::GradVarName("Y")); - auto delta = static_cast(context.Attr("delta")); - auto& place = - *context.template device_context().eigen_device(); - - auto residual = EigenVector::Flatten(*in0); - auto out_grad = EigenVector::Flatten(*in1); - - if (out0) { - 
out0->mutable_data(context.GetPlace()); - auto x_grad = EigenVector::Flatten(*out0); - x_grad.device(place) = - residual.unaryExpr(HuberLossBackward(delta, -1.0)); - x_grad.device(place) = out_grad * x_grad; - } - - if (out1) { - out1->mutable_data(context.GetPlace()); - auto y_grad = EigenVector::Flatten(*out1); - y_grad.device(place) = - residual.unaryExpr(HuberLossBackward(delta, 1.0)); - y_grad.device(place) = out_grad * y_grad; - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/huber_loss_op_npu.cc b/paddle/fluid/operators/huber_loss_op_npu.cc index 19ced131c00a2..6fc6960d3db56 100644 --- a/paddle/fluid/operators/huber_loss_op_npu.cc +++ b/paddle/fluid/operators/huber_loss_op_npu.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/huber_loss_op.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" namespace paddle { diff --git a/paddle/fluid/operators/huber_loss_op_xpu.cc b/paddle/fluid/operators/huber_loss_op_xpu.cc index 767ce542736e8..ccddec2779515 100644 --- a/paddle/fluid/operators/huber_loss_op_xpu.cc +++ b/paddle/fluid/operators/huber_loss_op_xpu.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #ifdef PADDLE_WITH_XPU - -#include "paddle/fluid/operators/huber_loss_op.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { diff --git a/paddle/pten/kernels/cpu/huber_loss_grad_kernel.cc b/paddle/pten/kernels/cpu/huber_loss_grad_kernel.cc new file mode 100644 index 0000000000000..0aaf860f4c099 --- /dev/null +++ b/paddle/pten/kernels/cpu/huber_loss_grad_kernel.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/huber_loss_grad_kernel.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/impl/huber_loss_grad_kernel_impl.h" + +PT_REGISTER_KERNEL(huber_loss_grad, + CPU, + ALL_LAYOUT, + pten::HuberLossGradKernel, + float, + double) {} diff --git a/paddle/pten/kernels/cpu/huber_loss_kernel.cc b/paddle/pten/kernels/cpu/huber_loss_kernel.cc new file mode 100644 index 0000000000000..a620db12f31d5 --- /dev/null +++ b/paddle/pten/kernels/cpu/huber_loss_kernel.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/huber_loss_kernel.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/impl/huber_loss_kernel_impl.h" + +PT_REGISTER_KERNEL( + huber_loss, CPU, ALL_LAYOUT, pten::HuberLossKernel, float, double) {} diff --git a/paddle/pten/kernels/gpu/huber_loss_grad_kernel.cu b/paddle/pten/kernels/gpu/huber_loss_grad_kernel.cu new file mode 100644 index 0000000000000..951d6a73f0176 --- /dev/null +++ b/paddle/pten/kernels/gpu/huber_loss_grad_kernel.cu @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/huber_loss_grad_kernel.h" +#include "paddle/pten/kernels/impl/huber_loss_grad_kernel_impl.h" + +PT_REGISTER_KERNEL(huber_loss_grad, + GPU, + ALL_LAYOUT, + pten::HuberLossGradKernel, + float, + double) {} diff --git a/paddle/pten/kernels/gpu/huber_loss_kernel.cu b/paddle/pten/kernels/gpu/huber_loss_kernel.cu new file mode 100644 index 0000000000000..ef5361d73f839 --- /dev/null +++ b/paddle/pten/kernels/gpu/huber_loss_kernel.cu @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/huber_loss_kernel.h" +#include "paddle/pten/kernels/impl/huber_loss_kernel_impl.h" + +PT_REGISTER_KERNEL( + huber_loss, GPU, ALL_LAYOUT, pten::HuberLossKernel, float, double) {} diff --git a/paddle/pten/kernels/huber_loss_grad_kernel.h b/paddle/pten/kernels/huber_loss_grad_kernel.h new file mode 100644 index 0000000000000..ec2046e2b06cb --- /dev/null +++ b/paddle/pten/kernels/huber_loss_grad_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/device_context.h" + +namespace pten { + +template +void HuberLossGradKernel(const Context& dev_ctx, + const DenseTensor& residual, + const DenseTensor& out_grad, + float delta, + DenseTensor* input_grad, + DenseTensor* label_grad); + +} // namespace pten diff --git a/paddle/pten/kernels/huber_loss_kernel.h b/paddle/pten/kernels/huber_loss_kernel.h new file mode 100644 index 0000000000000..efff64fda42e7 --- /dev/null +++ b/paddle/pten/kernels/huber_loss_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/device_context.h" + +namespace pten { + +template +void HuberLossKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& label, + float delta, + DenseTensor* out, + DenseTensor* residual); + +} // namespace pten diff --git a/paddle/pten/kernels/impl/huber_loss_grad_kernel_impl.h b/paddle/pten/kernels/impl/huber_loss_grad_kernel_impl.h new file mode 100644 index 0000000000000..fde7754890424 --- /dev/null +++ b/paddle/pten/kernels/impl/huber_loss_grad_kernel_impl.h @@ -0,0 +1,75 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/pten/kernels/funcs/eigen/common.h" +#include "paddle/pten/kernels/funcs/eigen/eigen_function.h" +#include "paddle/pten/kernels/huber_loss_grad_kernel.h" + +namespace pten { + +template +struct HuberLossBackward { + HOSTDEVICE HuberLossBackward(const T& delta, T sign) + : sign(sign), delta(delta) {} + + HOSTDEVICE T operator()(const T& val) const { + T abs_val = std::abs(val); + if (abs_val <= delta) { + return sign * val; + } else { + if (val > 0) { + return sign * delta; + } else { + return -1 * sign * delta; + } + } + } + + T sign; + T delta; +}; + +template +void HuberLossGradKernel(const Context& dev_ctx, + const DenseTensor& residual, + const DenseTensor& out_grad, + float delta, + DenseTensor* input_grad, + DenseTensor* label_grad) { + T delta_ = static_cast(delta); + auto& place = *dev_ctx.eigen_device(); + + auto eigen_residual = EigenVector::Flatten(residual); + auto eigen_out_grad = EigenVector::Flatten(out_grad); + + if (input_grad) { + dev_ctx.template Alloc(input_grad); + auto eigen_input_grad = EigenVector::Flatten(*input_grad); + eigen_input_grad.device(place) = + eigen_residual.unaryExpr(HuberLossBackward(delta_, -1.0)); + eigen_input_grad.device(place) = eigen_out_grad * eigen_input_grad; + } + + if (label_grad) { + dev_ctx.template Alloc(label_grad); + auto eigen_label_grad = EigenVector::Flatten(*label_grad); + eigen_label_grad.device(place) = + eigen_residual.unaryExpr(HuberLossBackward(delta_, 1.0)); + eigen_label_grad.device(place) = eigen_out_grad * eigen_label_grad; + } +} + +} // namespace pten diff --git a/paddle/pten/kernels/impl/huber_loss_kernel_impl.h b/paddle/pten/kernels/impl/huber_loss_kernel_impl.h new file mode 100644 index 0000000000000..3730d6e463624 --- /dev/null +++ b/paddle/pten/kernels/impl/huber_loss_kernel_impl.h @@ -0,0 +1,61 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/pten/kernels/funcs/eigen/common.h" +#include "paddle/pten/kernels/funcs/eigen/eigen_function.h" +#include "paddle/pten/kernels/huber_loss_kernel.h" + +namespace pten { + +template +struct HuberLossForward { + HOSTDEVICE HuberLossForward(const T& delta) : delta(delta) {} + + HOSTDEVICE T operator()(const T& val) const { + T abs_val = std::abs(val); + if (abs_val <= delta) { + return static_cast(0.5) * val * val; + } else { + return delta * (abs_val - static_cast(0.5) * delta); + } + } + + T delta; +}; + +template +void HuberLossKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& label, + float delta, + DenseTensor* out, + DenseTensor* residual) { + T delta_ = static_cast(delta); + auto& place = *dev_ctx.eigen_device(); + + auto x = EigenVector::Flatten(input); + auto y = EigenVector::Flatten(label); + + dev_ctx.template Alloc(residual); + auto eigen_residual = EigenVector::Flatten(*residual); + eigen_residual.device(place) = y - x; + + dev_ctx.template Alloc(out); + auto loss = EigenVector::Flatten(*out); + loss.device(place) = eigen_residual.unaryExpr(HuberLossForward(delta_)); +} + +} // namespace pten diff --git a/paddle/pten/ops/compat/huber_loss_sig.cc b/paddle/pten/ops/compat/huber_loss_sig.cc new file mode 100644 index 0000000000000..88a8433c6c45f --- /dev/null +++ b/paddle/pten/ops/compat/huber_loss_sig.cc @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature HuberLossOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "huber_loss", {"X", "Y"}, {"delta"}, {"Out", "Residual"}); +} + +KernelSignature HuberLossGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("huber_loss_grad", + {"Residual", GradVarName("Out")}, + {"delta"}, + {GradVarName("X"), GradVarName("Y")}); +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(huber_loss, pten::HuberLossOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(huber_loss_grad, + pten::HuberLossGradOpArgumentMapping); From 7bc2644e2e3ce31a7bc7a7d46c54ee68d80f7d2e Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 21 Feb 2022 03:27:48 +0000 Subject: [PATCH 2/3] migrate infershape --- paddle/fluid/operators/huber_loss_op.cc | 38 +++++-------------------- paddle/pten/infermeta/binary.cc | 37 ++++++++++++++++++++++++ paddle/pten/infermeta/binary.h | 7 +++++ 3 files changed, 51 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/huber_loss_op.cc b/paddle/fluid/operators/huber_loss_op.cc index a4b1f77d80947..8f3b60a53d71a 100644 --- a/paddle/fluid/operators/huber_loss_op.cc +++ b/paddle/fluid/operators/huber_loss_op.cc @@ -16,7 +16,9 @@ limitations under the License. 
*/ #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/pten/infermeta/binary.h" namespace paddle { namespace operators { @@ -24,36 +26,6 @@ namespace operators { class HuberLossOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "HuberLoss"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "HuberLoss"); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - - PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(), - platform::errors::InvalidArgument( - "Input(input) rank and Input(label) rank should be " - "same, but received input rank(%d) != label rank(%d)", - x_dims.size(), y_dims.size())); - - bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) || - framework::contain_unknown_dim(y_dims); - if (ctx->IsRuntime() || !contain_unknown_dim) { - PADDLE_ENFORCE_EQ( - x_dims, y_dims, - platform::errors::InvalidArgument( - "The Input(input) and Input(label) should have the same " - "shape, but received input shape [%s] != label shape [%s]", - x_dims, y_dims)); - } - - auto out_dims = y_dims; - ctx->SetOutputDim("Residual", out_dims); - ctx->SetOutputDim("Out", out_dims); - ctx->ShareLoD("X", "Out"); - } }; template @@ -140,7 +112,11 @@ class HuberLossGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(huber_loss, HuberLossInferShapeFunctor, + PT_INFER_META(pten::HuberLossInferMeta)); + REGISTER_OPERATOR(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker, ops::HuberLossGradOpMaker, - ops::HuberLossGradOpMaker); + ops::HuberLossGradOpMaker, + HuberLossInferShapeFunctor); REGISTER_OPERATOR(huber_loss_grad, ops::HuberLossGradOp); diff --git a/paddle/pten/infermeta/binary.cc b/paddle/pten/infermeta/binary.cc index 02d78b5caa722..6c262e38e87e1 100644 --- a/paddle/pten/infermeta/binary.cc +++ b/paddle/pten/infermeta/binary.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/pten/infermeta/binary.h" +#include "paddle/pten/core/ddim.h" #include "paddle/pten/kernels/funcs/common_shape.h" namespace pten { @@ -188,4 +189,40 @@ void ElementwiseRawInferMeta(const MetaTensor& x, out->share_lod(x); } +void HuberLossInferMeta(const MetaTensor& input, + const MetaTensor& label, + float delta, + MetaTensor* out, + MetaTensor* residual, + MetaConfig config) { + auto input_dims = input.dims(); + auto label_dims = label.dims(); + + PADDLE_ENFORCE_EQ(input_dims.size(), + label_dims.size(), + pten::errors::InvalidArgument( + "Input(input) rank and Input(label) rank should be " + "same, but received input rank(%d) != label rank(%d)", + input_dims.size(), + label_dims.size())); + + bool contain_unknown_dim = pten::contain_unknown_dim(input_dims) || + pten::contain_unknown_dim(label_dims); + if (config.is_runtime || !contain_unknown_dim) { + PADDLE_ENFORCE_EQ( + input_dims, + label_dims, + pten::errors::InvalidArgument( + "The Input(input) and Input(label) should have the same " + "shape, but received input shape [%s] != label shape [%s]", + input_dims, + label_dims)); + } + + auto out_dims = label_dims; + residual->set_dims(out_dims); + out->set_dims(out_dims); + out->share_lod(input); +} + } // namespace pten diff --git a/paddle/pten/infermeta/binary.h b/paddle/pten/infermeta/binary.h index 992082464967c..c379ba6043e63 100644 --- a/paddle/pten/infermeta/binary.h +++ b/paddle/pten/infermeta/binary.h @@ -45,4 +45,11 @@ void ElementwiseRawInferMeta(const MetaTensor& x_meta, const MetaTensor& y_meta, int axis, MetaTensor* out); + +void HuberLossInferMeta(const MetaTensor& input_meta, + const MetaTensor& label_meta, + float delta, + MetaTensor* out, + MetaTensor* residual, + MetaConfig config = MetaConfig()); } // namespace pten From 1f825b6bb15d58c1515187c0a0915afb04800c32 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Mon, 21 Feb 2022 05:06:00 +0000 Subject: [PATCH 3/3] modify pten into phi --- .../kernels/cpu/huber_loss_grad_kernel.cc | 17 +++++++---------- .../kernels/cpu/huber_loss_kernel.cc | 10 +++++----- .../kernels/gpu/huber_loss_grad_kernel.cu | 17 +++++++---------- .../kernels/gpu/huber_loss_kernel.cu | 10 +++++----- .../kernels/huber_loss_grad_kernel.h | 8 ++++---- .../{pten => phi}/kernels/huber_loss_kernel.h | 8 ++++---- .../kernels/impl/huber_loss_grad_kernel_impl.h | 10 +++++----- .../kernels/impl/huber_loss_kernel_impl.h | 10 +++++----- .../{pten => phi}/ops/compat/huber_loss_sig.cc | 10 +++++----- 9 files changed, 47 insertions(+), 53 deletions(-) rename paddle/{pten => phi}/kernels/cpu/huber_loss_grad_kernel.cc (60%) rename paddle/{pten => phi}/kernels/cpu/huber_loss_kernel.cc (69%) rename paddle/{pten => phi}/kernels/gpu/huber_loss_grad_kernel.cu (60%) rename paddle/{pten => phi}/kernels/gpu/huber_loss_kernel.cu (69%) rename paddle/{pten => phi}/kernels/huber_loss_grad_kernel.h (88%) rename paddle/{pten => phi}/kernels/huber_loss_kernel.h (88%) rename paddle/{pten => phi}/kernels/impl/huber_loss_grad_kernel_impl.h (91%) rename paddle/{pten => phi}/kernels/impl/huber_loss_kernel_impl.h (89%) rename paddle/{pten => phi}/ops/compat/huber_loss_sig.cc (83%) diff --git a/paddle/pten/kernels/cpu/huber_loss_grad_kernel.cc b/paddle/phi/kernels/cpu/huber_loss_grad_kernel.cc similarity index 60% rename from paddle/pten/kernels/cpu/huber_loss_grad_kernel.cc rename to paddle/phi/kernels/cpu/huber_loss_grad_kernel.cc index 0aaf860f4c099..bd2349393e742 100644 --- a/paddle/pten/kernels/cpu/huber_loss_grad_kernel.cc +++ 
b/paddle/phi/kernels/cpu/huber_loss_grad_kernel.cc @@ -12,14 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pten/kernels/huber_loss_grad_kernel.h" -#include "paddle/pten/backends/cpu/cpu_context.h" -#include "paddle/pten/core/kernel_registry.h" -#include "paddle/pten/kernels/impl/huber_loss_grad_kernel_impl.h" +#include "paddle/phi/kernels/huber_loss_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/huber_loss_grad_kernel_impl.h" -PT_REGISTER_KERNEL(huber_loss_grad, - CPU, - ALL_LAYOUT, - pten::HuberLossGradKernel, - float, - double) {} +PT_REGISTER_KERNEL( + huber_loss_grad, CPU, ALL_LAYOUT, phi::HuberLossGradKernel, float, double) { +} diff --git a/paddle/pten/kernels/cpu/huber_loss_kernel.cc b/paddle/phi/kernels/cpu/huber_loss_kernel.cc similarity index 69% rename from paddle/pten/kernels/cpu/huber_loss_kernel.cc rename to paddle/phi/kernels/cpu/huber_loss_kernel.cc index a620db12f31d5..dfdab16bc85e3 100644 --- a/paddle/pten/kernels/cpu/huber_loss_kernel.cc +++ b/paddle/phi/kernels/cpu/huber_loss_kernel.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pten/kernels/huber_loss_kernel.h" -#include "paddle/pten/backends/cpu/cpu_context.h" -#include "paddle/pten/core/kernel_registry.h" -#include "paddle/pten/kernels/impl/huber_loss_kernel_impl.h" +#include "paddle/phi/kernels/huber_loss_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/huber_loss_kernel_impl.h" PT_REGISTER_KERNEL( - huber_loss, CPU, ALL_LAYOUT, pten::HuberLossKernel, float, double) {} + huber_loss, CPU, ALL_LAYOUT, phi::HuberLossKernel, float, double) {} diff --git a/paddle/pten/kernels/gpu/huber_loss_grad_kernel.cu b/paddle/phi/kernels/gpu/huber_loss_grad_kernel.cu similarity index 60% rename from paddle/pten/kernels/gpu/huber_loss_grad_kernel.cu rename to paddle/phi/kernels/gpu/huber_loss_grad_kernel.cu index 951d6a73f0176..5e1e000a38d95 100644 --- a/paddle/pten/kernels/gpu/huber_loss_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/huber_loss_grad_kernel.cu @@ -12,14 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/pten/backends/gpu/gpu_context.h" -#include "paddle/pten/core/kernel_registry.h" -#include "paddle/pten/kernels/huber_loss_grad_kernel.h" -#include "paddle/pten/kernels/impl/huber_loss_grad_kernel_impl.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/huber_loss_grad_kernel.h" +#include "paddle/phi/kernels/impl/huber_loss_grad_kernel_impl.h" -PT_REGISTER_KERNEL(huber_loss_grad, - GPU, - ALL_LAYOUT, - pten::HuberLossGradKernel, - float, - double) {} +PT_REGISTER_KERNEL( + huber_loss_grad, GPU, ALL_LAYOUT, phi::HuberLossGradKernel, float, double) { +} diff --git a/paddle/pten/kernels/gpu/huber_loss_kernel.cu b/paddle/phi/kernels/gpu/huber_loss_kernel.cu similarity index 69% rename from paddle/pten/kernels/gpu/huber_loss_kernel.cu rename to paddle/phi/kernels/gpu/huber_loss_kernel.cu index ef5361d73f839..2cca0c08a3f3b 100644 --- a/paddle/pten/kernels/gpu/huber_loss_kernel.cu +++ b/paddle/phi/kernels/gpu/huber_loss_kernel.cu @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pten/backends/gpu/gpu_context.h" -#include "paddle/pten/core/kernel_registry.h" -#include "paddle/pten/kernels/huber_loss_kernel.h" -#include "paddle/pten/kernels/impl/huber_loss_kernel_impl.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/huber_loss_kernel.h" +#include "paddle/phi/kernels/impl/huber_loss_kernel_impl.h" PT_REGISTER_KERNEL( - huber_loss, GPU, ALL_LAYOUT, pten::HuberLossKernel, float, double) {} + huber_loss, GPU, ALL_LAYOUT, phi::HuberLossKernel, float, double) {} diff --git a/paddle/pten/kernels/huber_loss_grad_kernel.h b/paddle/phi/kernels/huber_loss_grad_kernel.h similarity index 88% rename from paddle/pten/kernels/huber_loss_grad_kernel.h rename to paddle/phi/kernels/huber_loss_grad_kernel.h index ec2046e2b06cb..c6246b1553197 100644 --- a/paddle/pten/kernels/huber_loss_grad_kernel.h +++ b/paddle/phi/kernels/huber_loss_grad_kernel.h @@ -14,10 +14,10 @@ #pragma once -#include "paddle/pten/core/dense_tensor.h" -#include "paddle/pten/core/device_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" -namespace pten { +namespace phi { template void HuberLossGradKernel(const Context& dev_ctx, @@ -27,4 +27,4 @@ void HuberLossGradKernel(const Context& dev_ctx, DenseTensor* input_grad, DenseTensor* label_grad); -} // namespace pten +} // namespace phi diff --git a/paddle/pten/kernels/huber_loss_kernel.h b/paddle/phi/kernels/huber_loss_kernel.h similarity index 88% rename from paddle/pten/kernels/huber_loss_kernel.h rename to paddle/phi/kernels/huber_loss_kernel.h index efff64fda42e7..3533a9ec6ded5 100644 --- a/paddle/pten/kernels/huber_loss_kernel.h +++ b/paddle/phi/kernels/huber_loss_kernel.h @@ -14,10 +14,10 @@ #pragma once -#include "paddle/pten/core/dense_tensor.h" -#include "paddle/pten/core/device_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/device_context.h" -namespace pten { +namespace phi { template void HuberLossKernel(const Context& dev_ctx, @@ -27,4 +27,4 @@ void HuberLossKernel(const Context& dev_ctx, DenseTensor* out, DenseTensor* residual); -} // namespace pten +} // namespace phi diff --git a/paddle/pten/kernels/impl/huber_loss_grad_kernel_impl.h b/paddle/phi/kernels/impl/huber_loss_grad_kernel_impl.h similarity index 91% rename from 
paddle/pten/kernels/impl/huber_loss_grad_kernel_impl.h rename to paddle/phi/kernels/impl/huber_loss_grad_kernel_impl.h index fde7754890424..b93578abba2b7 100644 --- a/paddle/pten/kernels/impl/huber_loss_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/huber_loss_grad_kernel_impl.h @@ -14,11 +14,11 @@ #pragma once -#include "paddle/pten/kernels/funcs/eigen/common.h" -#include "paddle/pten/kernels/funcs/eigen/eigen_function.h" -#include "paddle/pten/kernels/huber_loss_grad_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/huber_loss_grad_kernel.h" -namespace pten { +namespace phi { template struct HuberLossBackward { @@ -72,4 +72,4 @@ void HuberLossGradKernel(const Context& dev_ctx, } } -} // namespace pten +} // namespace phi diff --git a/paddle/pten/kernels/impl/huber_loss_kernel_impl.h b/paddle/phi/kernels/impl/huber_loss_kernel_impl.h similarity index 89% rename from paddle/pten/kernels/impl/huber_loss_kernel_impl.h rename to paddle/phi/kernels/impl/huber_loss_kernel_impl.h index 3730d6e463624..7fbdc80c3829b 100644 --- a/paddle/pten/kernels/impl/huber_loss_kernel_impl.h +++ b/paddle/phi/kernels/impl/huber_loss_kernel_impl.h @@ -14,11 +14,11 @@ #pragma once -#include "paddle/pten/kernels/funcs/eigen/common.h" -#include "paddle/pten/kernels/funcs/eigen/eigen_function.h" -#include "paddle/pten/kernels/huber_loss_kernel.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/huber_loss_kernel.h" -namespace pten { +namespace phi { template struct HuberLossForward { @@ -58,4 +58,4 @@ void HuberLossKernel(const Context& dev_ctx, loss.device(place) = eigen_residual.unaryExpr(HuberLossForward(delta_)); } -} // namespace pten +} // namespace phi diff --git a/paddle/pten/ops/compat/huber_loss_sig.cc b/paddle/phi/ops/compat/huber_loss_sig.cc similarity index 83% rename from paddle/pten/ops/compat/huber_loss_sig.cc rename to paddle/phi/ops/compat/huber_loss_sig.cc index 88a8433c6c45f..6e7183ff9f281 100644 --- a/paddle/pten/ops/compat/huber_loss_sig.cc +++ b/paddle/phi/ops/compat/huber_loss_sig.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pten/core/compat/op_utils.h" +#include "paddle/phi/core/compat/op_utils.h" -namespace pten { +namespace phi { KernelSignature HuberLossOpArgumentMapping(const ArgumentMappingContext& ctx) { return KernelSignature( @@ -29,8 +29,8 @@ KernelSignature HuberLossGradOpArgumentMapping( {GradVarName("X"), GradVarName("Y")}); } -} // namespace pten +} // namespace phi -PT_REGISTER_ARG_MAPPING_FN(huber_loss, pten::HuberLossOpArgumentMapping); +PT_REGISTER_ARG_MAPPING_FN(huber_loss, phi::HuberLossOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(huber_loss_grad, - pten::HuberLossGradOpArgumentMapping); + phi::HuberLossGradOpArgumentMapping);
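For reviewers of the kernels above, the snippet below is a standalone sketch of the piecewise computation that HuberLossForward and HuberLossBackward implement in huber_loss_kernel_impl.h and huber_loss_grad_kernel_impl.h. It is illustration only, not part of the patch or the repository: it has no Paddle dependency, the functor bodies are transcribed from the kernels above, and the inputs in main() are arbitrary demo values.

// Standalone illustration of the Huber loss forward/backward logic used by
// the migrated phi kernels. Compiles with any C++11 compiler; not Paddle code.
#include <cmath>
#include <cstdio>

template <typename T>
T HuberLossForwardValue(T val, T delta) {
  T abs_val = std::abs(val);
  if (abs_val <= delta) {
    // Quadratic region: 0.5 * residual^2
    return static_cast<T>(0.5) * val * val;
  }
  // Linear region: delta * (|residual| - 0.5 * delta)
  return delta * (abs_val - static_cast<T>(0.5) * delta);
}

template <typename T>
T HuberLossBackwardValue(T val, T delta, T sign) {
  // sign is -1 for d(loss)/dX and +1 for d(loss)/dY, as in the grad kernel.
  T abs_val = std::abs(val);
  if (abs_val <= delta) {
    return sign * val;
  }
  return val > 0 ? sign * delta : -sign * delta;
}

int main() {
  const double delta = 1.0;
  const double residual = 2.5;  // residual = y - x, demo value only
  std::printf("loss     = %f\n", HuberLossForwardValue(residual, delta));       // 2.0
  std::printf("dloss/dx = %f\n", HuberLossBackwardValue(residual, delta, -1.0)); // -1.0
  std::printf("dloss/dy = %f\n", HuberLossBackwardValue(residual, delta, 1.0));  // 1.0
  return 0;
}

In the actual kernels these per-element functions are applied through Eigen's unaryExpr over the flattened residual tensor, and the gradient is then scaled elementwise by out_grad, exactly as the impl headers in this patch do.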