diff --git a/paddle/fluid/operators/dist_op.cc b/paddle/fluid/operators/dist_op.cc index 3a53f1365567f..10750574c4573 100644 --- a/paddle/fluid/operators/dist_op.cc +++ b/paddle/fluid/operators/dist_op.cc @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/dist_op.h" #include #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" + namespace paddle { namespace operators { @@ -121,13 +124,11 @@ class DistGradOpMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; +DELCARE_INFER_SHAPE_FUNCTOR(dist, DistInferShapeFunctor, + PT_INFER_META(phi::DistInferMeta)); + REGISTER_OPERATOR(dist, ops::DistOp, ops::DistOpMaker, ops::DistGradOpMaker, - ops::DistGradOpMaker); + ops::DistGradOpMaker, + DistInferShapeFunctor); REGISTER_OPERATOR(dist_grad, ops::DistOpGrad); -REGISTER_OP_CPU_KERNEL( - dist, ops::DistKernel, - ops::DistKernel); -REGISTER_OP_CPU_KERNEL( - dist_grad, ops::DistGradKernel, - ops::DistGradKernel) diff --git a/paddle/fluid/operators/dist_op.h b/paddle/fluid/operators/dist_op.h deleted file mode 100644 index dfd7e29a8d010..0000000000000 --- a/paddle/fluid/operators/dist_op.h +++ /dev/null @@ -1,304 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -template -using EigenTensor = framework::EigenTensor; -using framework::Tensor; - -template -static void GetBraodcastDims(const framework::DDim& x_dims, - const framework::DDim& y_dims, - Eigen::DSizes* x_bcast_dims, - Eigen::DSizes* y_bcast_dims) { - int bcast_dims_remainder = 0; - for (int i = 0; i < x_dims.size(); ++i) { - if (x_dims[i] >= y_dims[i]) { - (*x_bcast_dims)[i] = 1; - (*y_bcast_dims)[i] = x_dims[i] / y_dims[i]; - bcast_dims_remainder += x_dims[i] % y_dims[i]; - } else { - (*y_bcast_dims)[i] = 1; - (*x_bcast_dims)[i] = y_dims[i] / x_dims[i]; - bcast_dims_remainder += y_dims[i] % x_dims[i]; - } - } - PADDLE_ENFORCE_EQ(bcast_dims_remainder, 0, - platform::errors::PreconditionNotMet( - "The input tensor of Op(dist) could not be broadcast, " - "X's shape is [%s], Y's shape is [%s].", - x_dims, y_dims)); -} - -static framework::DDim GetNewDims(const framework::DDim& in_dims, int rank) { - std::vector new_dims_vec(rank); - if (in_dims.size() < rank) { - for (int i = 0; i < rank - in_dims.size(); ++i) { - new_dims_vec[i] = 1; - } - for (int i = 0; i < in_dims.size(); ++i) { - new_dims_vec[i + rank - in_dims.size()] = in_dims[i]; - } - } else { - new_dims_vec = vectorize(in_dims); - } - return phi::make_ddim(new_dims_vec); -} - -template -static void DistFunction(const framework::ExecutionContext& context) { - auto* x = context.Input("X"); - auto* y = context.Input("Y"); - auto* out = context.Output("Out"); - auto p = context.Attr("p"); - out->mutable_data(context.GetPlace()); - - auto x_dims = context.Input("X")->dims(); - auto y_dims = context.Input("Y")->dims(); - - // new dims with same size as rank, e.g. (rank=3, (4, 3) => (1, 4, 3)) - framework::DDim x_new_dims = GetNewDims(x_dims, Rank); - framework::DDim y_new_dims = GetNewDims(y_dims, Rank); - - auto x_t = EigenTensor::From(*x, x_new_dims); - auto y_t = EigenTensor::From(*y, y_new_dims); - auto out_t = EigenTensor::From(*out); - auto& place = - *context.template device_context().eigen_device(); - - Eigen::DSizes x_bcast_dims; - Eigen::DSizes y_bcast_dims; - GetBraodcastDims(x_new_dims, y_new_dims, &x_bcast_dims, &y_bcast_dims); - // p=0 means number of non-zero elements of (x-y) - // p=inf means the maximum of |x-y| - // p=-inf means the minimum of |x-y| - // otherwise, Lp-norm = pow(sum(pow(|x-y|, p)), 1/p) - if (p == 0) { - out_t.device(place) = - (x_t.broadcast(x_bcast_dims) != y_t.broadcast(y_bcast_dims)) - .template cast() - .sum(); - } else if (p == INFINITY) { - out_t.device(place) = - (x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims)) - .abs() - .maximum(); - } else if (p == -INFINITY) { - out_t.device(place) = - (x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims)) - .abs() - .minimum(); - } else { - out_t.device(place) = - (x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims)) - .abs() - .pow(p) - .sum() - .pow(1.0 / p); - } -} - -template -static void DistGradFunction(const framework::ExecutionContext& context) { - auto* x = context.Input("X"); - auto* y = context.Input("Y"); - auto* out = context.Input("Out"); - auto p = context.Attr("p"); - - auto x_grad = context.Output(framework::GradVarName("X")); - auto y_grad = context.Output(framework::GradVarName("Y")); - auto out_grad = context.Input(framework::GradVarName("Out")); - - auto x_dims = context.Input("X")->dims(); - auto y_dims = context.Input("Y")->dims(); - auto out_dims = context.Input("Out")->dims(); - - framework::DDim x_new_dims = GetNewDims(x_dims, Rank); - framework::DDim y_new_dims = GetNewDims(y_dims, Rank); - framework::DDim out_new_dims = GetNewDims(out_dims, Rank); - auto x_t = EigenTensor::From(*x, x_new_dims); - auto y_t = EigenTensor::From(*y, y_new_dims); - auto out_t = EigenTensor::From(*out, out_new_dims); - - Eigen::DSizes x_bcast_dims; - Eigen::DSizes y_bcast_dims; - Eigen::DSizes out_bcast_dims; - - GetBraodcastDims(x_new_dims, y_new_dims, &x_bcast_dims, &y_bcast_dims); - std::vector new_dims_vec(Rank); - for (int i = 0; i < Rank; ++i) { - new_dims_vec[i] = std::max(x_new_dims[i], y_new_dims[i]); - out_bcast_dims[i] = new_dims_vec[i]; - } - framework::DDim new_dims = phi::make_ddim(new_dims_vec); - - auto& place = - *context.template device_context().eigen_device(); - auto out_grad_t = EigenTensor::From(*out_grad, out_new_dims); - framework::Tensor grad; - grad.mutable_data(new_dims, context.GetPlace()); - auto grad_t = EigenTensor::From(grad); - - auto x_minux_y = x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims); - auto x_minux_y_abs = x_minux_y.abs(); - auto sign = - (x_minux_y > static_cast(0)).template cast() * static_cast(1.0) + - (x_minux_y < static_cast(0)).template cast() * static_cast(-1.0); - T epsilon = static_cast(1.0e-10f); - - // 1: Lp-norm(z), z = x-y, compute dz - if (p == 0) { - phi::funcs::SetConstant set_zero; - auto& dev_ctx = context.template device_context(); - set_zero(dev_ctx, &grad, static_cast(0)); - } else if (p == INFINITY || p == -INFINITY) { - // p=inf or -inf, Lp-norm = |z_i|, the j-th element of dz tends to 0 if - // j!=i, or equals to sign(z_i) * dout if j=i. - if (platform::is_cpu_place(context.GetPlace())) { - grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims)) - .template cast() * - sign.eval() * out_grad_t.broadcast(out_bcast_dims); - } else { - grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims)) - .template cast() * - sign * out_grad_t.broadcast(out_bcast_dims); - } - } else { - // dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout - if (platform::is_cpu_place(context.GetPlace())) { - grad_t.device(place) = - (x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims)) - .pow(p - 1) * - sign.eval() * out_grad_t.broadcast(out_bcast_dims); - } else { - grad_t.device(place) = - (x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims)) - .pow(p - 1) * - sign * out_grad_t.broadcast(out_bcast_dims); - } - } - - Eigen::DSizes x_reshape_dims; - Eigen::DSizes y_reshape_dims; - Eigen::DSizes reduce_dims; - for (int i = 0; i < x_new_dims.size(); ++i) { - x_reshape_dims[2 * i] = x_bcast_dims[i]; - x_reshape_dims[2 * i + 1] = x_new_dims[i]; - y_reshape_dims[2 * i] = y_bcast_dims[i]; - y_reshape_dims[2 * i + 1] = y_new_dims[i]; - reduce_dims[i] = 2 * i; - } - - // 2: if x or y is broadcasted in forward function, - // the grad need to be sum along the broadcasted dimensions - if (x_grad) { - x_grad->mutable_data(context.GetPlace()); - auto x_grad_t = EigenTensor::From(*x_grad, x_new_dims); - x_grad_t.device(place) = grad_t.reshape(x_reshape_dims) - .sum(reduce_dims) - .reshape(x_grad_t.dimensions()); - } - if (y_grad) { - y_grad->mutable_data(context.GetPlace()); - auto y_grad_t = EigenTensor::From(*y_grad, y_new_dims); - y_grad_t.device(place) = -grad_t.reshape(y_reshape_dims) - .sum(reduce_dims) - .reshape(y_grad_t.dimensions()); - } -} - -template -class DistKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto x_rank = context.Input("X")->dims().size(); - auto y_rank = context.Input("Y")->dims().size(); - auto rank = std::max(x_rank, y_rank); - PADDLE_ENFORCE_LE(rank, 6, - platform::errors::Unimplemented( - "Op(dist) only support tensors with no more than 6 " - "dimensions, but X's rank is %d, Y's rank is %d.", - x_rank, y_rank)); - switch (rank) { - case 1: - DistFunction(context); - break; - case 2: - DistFunction(context); - break; - case 3: - DistFunction(context); - break; - case 4: - DistFunction(context); - break; - case 5: - DistFunction(context); - break; - case 6: - DistFunction(context); - break; - } - } -}; - -template -class DistGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto x_rank = context.Input("X")->dims().size(); - auto y_rank = context.Input("Y")->dims().size(); - auto rank = std::max(x_rank, y_rank); - PADDLE_ENFORCE_LE(rank, 6, - platform::errors::Unimplemented( - "Op(dist) only support tensors with no more than 6 " - "dimensions, but X's rank is %d, Y's rank is %d.", - x_rank, y_rank)); - switch (rank) { - case 1: - DistGradFunction(context); - break; - case 2: - DistGradFunction(context); - break; - case 3: - DistGradFunction(context); - break; - case 4: - DistGradFunction(context); - break; - case 5: - DistGradFunction(context); - break; - case 6: - DistGradFunction(context); - break; - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 745ddffabbe33..c3c159dfc0a69 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -397,6 +397,29 @@ void BCELossInferMeta(const MetaTensor& input, out->share_lod(input); } +void DistInferMeta(const MetaTensor& x, + const MetaTensor& y, + float p, + MetaTensor* out) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + + PADDLE_ENFORCE_NE(phi::product(x_dims), + 0, + phi::errors::InvalidArgument( + "The Input(X) has not been initialized properly. The " + "shape of Input(X) = [%s].", + x_dims)); + PADDLE_ENFORCE_NE(phi::product(y_dims), + 0, + phi::errors::InvalidArgument( + "The Input(Y) has not been initialized properly. The " + "shape of Input(Y) = [%s].", + y_dims)); + out->set_dims({1}); + out->set_dtype(x.dtype()); +} + void GatherNdInferMeta(const MetaTensor& x, const MetaTensor& index, MetaTensor* out) { diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 2ec744636988f..11db19fc410d0 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -78,6 +78,11 @@ void BCELossInferMeta(const MetaTensor& input, MetaTensor* out, MetaConfig config = MetaConfig()); +void DistInferMeta(const MetaTensor& x, + const MetaTensor& y, + float p, + MetaTensor* out); + void GatherNdInferMeta(const MetaTensor& x, const MetaTensor& index, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/dist_grad_kernel.cc b/paddle/phi/kernels/cpu/dist_grad_kernel.cc new file mode 100644 index 0000000000000..2b7f8f98f9473 --- /dev/null +++ b/paddle/phi/kernels/cpu/dist_grad_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/dist_grad_kernel.h" +#include "paddle/phi/kernels/impl/dist_grad_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL( + dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/dist_kernel.cc b/paddle/phi/kernels/cpu/dist_kernel.cc new file mode 100644 index 0000000000000..ccf3d4be83230 --- /dev/null +++ b/paddle/phi/kernels/cpu/dist_kernel.cc @@ -0,0 +1,21 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/dist_kernel.h" +#include "paddle/phi/kernels/impl/dist_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(dist, CPU, ALL_LAYOUT, phi::DistKernel, float, double) {} diff --git a/paddle/phi/kernels/dist_grad_kernel.h b/paddle/phi/kernels/dist_grad_kernel.h new file mode 100644 index 0000000000000..1f8d7ff21f2fe --- /dev/null +++ b/paddle/phi/kernels/dist_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DistGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& out_grad, + float p, + DenseTensor* x_grad, + DenseTensor* y_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/dist_kernel.h b/paddle/phi/kernels/dist_kernel.h new file mode 100644 index 0000000000000..6cb3d6e0e8bef --- /dev/null +++ b/paddle/phi/kernels/dist_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DistKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + float p, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/dist_grad_kernel.cu b/paddle/phi/kernels/gpu/dist_grad_kernel.cu new file mode 100644 index 0000000000000..c458f8cce3e0a --- /dev/null +++ b/paddle/phi/kernels/gpu/dist_grad_kernel.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/dist_grad_kernel.h" +#include "paddle/phi/kernels/impl/dist_grad_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#ifdef PADDLE_WITH_HIP +PD_REGISTER_KERNEL(dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float) {} +#else +PD_REGISTER_KERNEL( + dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {} +#endif diff --git a/paddle/fluid/operators/dist_op.cu b/paddle/phi/kernels/gpu/dist_kernel.cu similarity index 51% rename from paddle/fluid/operators/dist_op.cu rename to paddle/phi/kernels/gpu/dist_kernel.cu index 90674969e283f..87e75e02754a8 100644 --- a/paddle/fluid/operators/dist_op.cu +++ b/paddle/phi/kernels/gpu/dist_kernel.cu @@ -1,4 +1,4 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,21 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/dist_op.h" +#include "paddle/phi/kernels/dist_kernel.h" +#include "paddle/phi/kernels/impl/dist_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" -namespace ops = paddle::operators; #ifdef PADDLE_WITH_HIP // Eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h:922 // do not support double in HIPCC platform (Eigen3 to be fixed) -REGISTER_OP_CUDA_KERNEL( - dist, ops::DistKernel); -REGISTER_OP_CUDA_KERNEL( - dist_grad, ops::DistGradKernel); +PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float) {} #else -REGISTER_OP_CUDA_KERNEL( - dist, ops::DistKernel, - ops::DistKernel); -REGISTER_OP_CUDA_KERNEL( - dist_grad, ops::DistGradKernel, - ops::DistGradKernel); +PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {} #endif diff --git a/paddle/phi/kernels/impl/dist_grad_kernel_impl.h b/paddle/phi/kernels/impl/dist_grad_kernel_impl.h new file mode 100644 index 0000000000000..fc118a832dc9f --- /dev/null +++ b/paddle/phi/kernels/impl/dist_grad_kernel_impl.h @@ -0,0 +1,223 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +using ETensor = phi::EigenTensor; + +template +static void GetBraodcastDims(const phi::DDim& x_dims, + const phi::DDim& y_dims, + Eigen::DSizes* x_bcast_dims, + Eigen::DSizes* y_bcast_dims) { + int bcast_dims_remainder = 0; + for (int i = 0; i < x_dims.size(); ++i) { + if (x_dims[i] >= y_dims[i]) { + (*x_bcast_dims)[i] = 1; + (*y_bcast_dims)[i] = x_dims[i] / y_dims[i]; + bcast_dims_remainder += x_dims[i] % y_dims[i]; + } else { + (*y_bcast_dims)[i] = 1; + (*x_bcast_dims)[i] = y_dims[i] / x_dims[i]; + bcast_dims_remainder += y_dims[i] % x_dims[i]; + } + } + PADDLE_ENFORCE_EQ(bcast_dims_remainder, + 0, + phi::errors::PreconditionNotMet( + "The input tensor of Op(dist) could not be broadcast, " + "X's shape is [%s], Y's shape is [%s].", + x_dims, + y_dims)); +} + +static phi::DDim GetNewDims(const phi::DDim& in_dims, int rank) { + std::vector new_dims_vec(rank); + if (in_dims.size() < rank) { + for (int i = 0; i < rank - in_dims.size(); ++i) { + new_dims_vec[i] = 1; + } + for (int i = 0; i < in_dims.size(); ++i) { + new_dims_vec[i + rank - in_dims.size()] = in_dims[i]; + } + } else { + new_dims_vec = vectorize(in_dims); + } + return phi::make_ddim(new_dims_vec); +} + +template +static void DistGradFunction(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& out_grad, + float p, + DenseTensor* x_grad, + DenseTensor* y_grad) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + auto out_dims = out.dims(); + + phi::DDim x_new_dims = GetNewDims(x_dims, Rank); + phi::DDim y_new_dims = GetNewDims(y_dims, Rank); + phi::DDim out_new_dims = GetNewDims(out_dims, Rank); + auto x_t = ETensor::From(x, x_new_dims); + auto y_t = ETensor::From(y, y_new_dims); + auto out_t = ETensor::From(out, out_new_dims); + + Eigen::DSizes x_bcast_dims; + Eigen::DSizes y_bcast_dims; + Eigen::DSizes out_bcast_dims; + + GetBraodcastDims(x_new_dims, y_new_dims, &x_bcast_dims, &y_bcast_dims); + std::vector new_dims_vec(Rank); + for (int i = 0; i < Rank; ++i) { + new_dims_vec[i] = std::max(x_new_dims[i], y_new_dims[i]); + out_bcast_dims[i] = new_dims_vec[i]; + } + phi::DDim new_dims = phi::make_ddim(new_dims_vec); + + auto& place = *dev_ctx.eigen_device(); + auto out_grad_t = ETensor::From(out_grad, out_new_dims); + DenseTensor grad; + grad.Resize(new_dims); + dev_ctx.template Alloc(&grad); + auto grad_t = ETensor::From(grad); + + auto x_minux_y = x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims); + auto x_minux_y_abs = x_minux_y.abs(); + auto sign = + (x_minux_y > static_cast(0)).template cast() * static_cast(1.0) + + (x_minux_y < static_cast(0)).template cast() * static_cast(-1.0); + T epsilon = static_cast(1.0e-10f); + + // 1: Lp-norm(z), z = x-y, compute dz + if (p == 0) { + phi::funcs::SetConstant set_zero; + set_zero(dev_ctx, &grad, static_cast(0)); + } else if (p == INFINITY || p == -INFINITY) { + // p=inf or -inf, Lp-norm = |z_i|, the j-th element of dz tends to 0 if + // j!=i, or equals to sign(z_i) * dout if j=i. + if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) { + grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims)) + .template cast() * + sign.eval() * out_grad_t.broadcast(out_bcast_dims); + } else { + grad_t.device(place) = (x_minux_y_abs == out_t.broadcast(out_bcast_dims)) + .template cast() * + sign * out_grad_t.broadcast(out_bcast_dims); + } + } else { + // dz = pow(abs(x-y)/out, p-1) * sign(x-y) * dout + if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) { + grad_t.device(place) = + (x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims)) + .pow(p - 1) * + sign.eval() * out_grad_t.broadcast(out_bcast_dims); + } else { + grad_t.device(place) = + (x_minux_y_abs / (out_t + epsilon).broadcast(out_bcast_dims)) + .pow(p - 1) * + sign * out_grad_t.broadcast(out_bcast_dims); + } + } + + Eigen::DSizes x_reshape_dims; + Eigen::DSizes y_reshape_dims; + Eigen::DSizes reduce_dims; + for (int i = 0; i < x_new_dims.size(); ++i) { + x_reshape_dims[2 * i] = x_bcast_dims[i]; + x_reshape_dims[2 * i + 1] = x_new_dims[i]; + y_reshape_dims[2 * i] = y_bcast_dims[i]; + y_reshape_dims[2 * i + 1] = y_new_dims[i]; + reduce_dims[i] = 2 * i; + } + + // 2: if x or y is broadcasted in forward function, + // the grad need to be sum along the broadcasted dimensions + if (x_grad) { + dev_ctx.template Alloc(x_grad); + auto x_grad_t = ETensor::From(*x_grad, x_new_dims); + x_grad_t.device(place) = grad_t.reshape(x_reshape_dims) + .sum(reduce_dims) + .reshape(x_grad_t.dimensions()); + } + if (y_grad) { + dev_ctx.template Alloc(y_grad); + auto y_grad_t = ETensor::From(*y_grad, y_new_dims); + y_grad_t.device(place) = -grad_t.reshape(y_reshape_dims) + .sum(reduce_dims) + .reshape(y_grad_t.dimensions()); + } +} + +template +void DistGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& out_grad, + float p, + DenseTensor* x_grad, + DenseTensor* y_grad) { + auto x_rank = x.dims().size(); + auto y_rank = y.dims().size(); + auto rank = std::max(x_rank, y_rank); + PADDLE_ENFORCE_LE(rank, + 6, + phi::errors::Unimplemented( + "Op(dist) only support tensors with no more than 6 " + "dimensions, but X's rank is %d, Y's rank is %d.", + x_rank, + y_rank)); + switch (rank) { + case 1: + DistGradFunction( + dev_ctx, x, y, out, out_grad, p, x_grad, y_grad); + break; + case 2: + DistGradFunction( + dev_ctx, x, y, out, out_grad, p, x_grad, y_grad); + break; + case 3: + DistGradFunction( + dev_ctx, x, y, out, out_grad, p, x_grad, y_grad); + break; + case 4: + DistGradFunction( + dev_ctx, x, y, out, out_grad, p, x_grad, y_grad); + break; + case 5: + DistGradFunction( + dev_ctx, x, y, out, out_grad, p, x_grad, y_grad); + break; + case 6: + DistGradFunction( + dev_ctx, x, y, out, out_grad, p, x_grad, y_grad); + break; + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/dist_kernel_impl.h b/paddle/phi/kernels/impl/dist_kernel_impl.h new file mode 100644 index 0000000000000..397fc1b922433 --- /dev/null +++ b/paddle/phi/kernels/impl/dist_kernel_impl.h @@ -0,0 +1,164 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +using ETensor = phi::EigenTensor; + +template +static void GetBraodcastDims(const phi::DDim& x_dims, + const phi::DDim& y_dims, + Eigen::DSizes* x_bcast_dims, + Eigen::DSizes* y_bcast_dims) { + int bcast_dims_remainder = 0; + for (int i = 0; i < x_dims.size(); ++i) { + if (x_dims[i] >= y_dims[i]) { + (*x_bcast_dims)[i] = 1; + (*y_bcast_dims)[i] = x_dims[i] / y_dims[i]; + bcast_dims_remainder += x_dims[i] % y_dims[i]; + } else { + (*y_bcast_dims)[i] = 1; + (*x_bcast_dims)[i] = y_dims[i] / x_dims[i]; + bcast_dims_remainder += y_dims[i] % x_dims[i]; + } + } + PADDLE_ENFORCE_EQ(bcast_dims_remainder, + 0, + phi::errors::PreconditionNotMet( + "The input tensor of Op(dist) could not be broadcast, " + "X's shape is [%s], Y's shape is [%s].", + x_dims, + y_dims)); +} + +static phi::DDim GetNewDims(const phi::DDim& in_dims, int rank) { + std::vector new_dims_vec(rank); + if (in_dims.size() < rank) { + for (int i = 0; i < rank - in_dims.size(); ++i) { + new_dims_vec[i] = 1; + } + for (int i = 0; i < in_dims.size(); ++i) { + new_dims_vec[i + rank - in_dims.size()] = in_dims[i]; + } + } else { + new_dims_vec = vectorize(in_dims); + } + return phi::make_ddim(new_dims_vec); +} + +template +static void DistFunction(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + float p, + DenseTensor* out) { + if (out) { + dev_ctx.template Alloc(out); + } + auto x_dims = x.dims(); + auto y_dims = y.dims(); + + // new dims with same size as rank, e.g. (rank=3, (4, 3) => (1, 4, 3)) + phi::DDim x_new_dims = GetNewDims(x_dims, Rank); + phi::DDim y_new_dims = GetNewDims(y_dims, Rank); + + auto x_t = ETensor::From(x, x_new_dims); + auto y_t = ETensor::From(y, y_new_dims); + auto out_t = ETensor::From(*out); + auto& place = *dev_ctx.eigen_device(); + + Eigen::DSizes x_bcast_dims; + Eigen::DSizes y_bcast_dims; + GetBraodcastDims(x_new_dims, y_new_dims, &x_bcast_dims, &y_bcast_dims); + // p=0 means number of non-zero elements of (x-y) + // p=inf means the maximum of |x-y| + // p=-inf means the minimum of |x-y| + // otherwise, Lp-norm = pow(sum(pow(|x-y|, p)), 1/p) + if (p == 0) { + out_t.device(place) = + (x_t.broadcast(x_bcast_dims) != y_t.broadcast(y_bcast_dims)) + .template cast() + .sum(); + } else if (p == INFINITY) { + out_t.device(place) = + (x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims)) + .abs() + .maximum(); + } else if (p == -INFINITY) { + out_t.device(place) = + (x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims)) + .abs() + .minimum(); + } else { + out_t.device(place) = + (x_t.broadcast(x_bcast_dims) - y_t.broadcast(y_bcast_dims)) + .abs() + .pow(p) + .sum() + .pow(1.0 / p); + } +} + +template +void DistKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + float p, + DenseTensor* out) { + auto x_rank = x.dims().size(); + auto y_rank = y.dims().size(); + auto rank = std::max(x_rank, y_rank); + PADDLE_ENFORCE_LE(rank, + 6, + phi::errors::Unimplemented( + "Op(dist) only support tensors with no more than 6 " + "dimensions, but X's rank is %d, Y's rank is %d.", + x_rank, + y_rank)); + switch (rank) { + case 1: + DistFunction(dev_ctx, x, y, p, out); + break; + case 2: + DistFunction(dev_ctx, x, y, p, out); + break; + case 3: + DistFunction(dev_ctx, x, y, p, out); + break; + case 4: + DistFunction(dev_ctx, x, y, p, out); + break; + case 5: + DistFunction(dev_ctx, x, y, p, out); + break; + case 6: + DistFunction(dev_ctx, x, y, p, out); + break; + } +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/dist_sig.cc b/paddle/phi/ops/compat/dist_sig.cc new file mode 100644 index 0000000000000..18a30b9b84048 --- /dev/null +++ b/paddle/phi/ops/compat/dist_sig.cc @@ -0,0 +1,28 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature DistGradOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("dist_grad", + {"X", "Y", "Out", GradVarName("Out")}, + {"p"}, + {GradVarName("X"), GradVarName("Y")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(dist_grad, phi::DistGradOpArgumentMapping);