diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc
index e7f18bdc88a11..d3adccff73337 100644
--- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc
+++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc
@@ -81,6 +81,8 @@ PD_DECLARE_KERNEL(sum, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(sum_grad, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(sgd, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(slice, GPU, ALL_LAYOUT);
+PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT);
 
 DECLARE_double(eager_delete_tensor_gb);
diff --git a/paddle/fluid/operators/lu_op.h b/paddle/fluid/operators/lu_op.h
index 2414ae68438fd..939558c710a3a 100644
--- a/paddle/fluid/operators/lu_op.h
+++ b/paddle/fluid/operators/lu_op.h
@@ -42,9 +42,12 @@ void SetValueCompute(const framework::ExecutionContext& ctx,
   auto dtype = framework::TransToProtoVarType(in->dtype());
   auto in_dims = in->dims();
 
-  CheckAndUpdateSliceAttrs(in_dims, axes, starts, ends, &steps);
-  auto slice_dims = GetSliceDims(in_dims, axes, *starts, *ends, &steps);
-  auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
+  phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, starts, ends,
+                                       &steps);
+  auto slice_dims =
+      phi::funcs::GetSliceDims(in_dims, axes, *starts, *ends, &steps);
+  auto decrease_slice_dims =
+      phi::funcs::GetDecreasedDims(slice_dims, decrease_axes);
 
   auto slice_dims_for_assign = decrease_slice_dims;
   if (!none_axes.empty()) {
@@ -282,10 +285,10 @@ void SliceCompute(const framework::ExecutionContext& ctx,
     }
   }
 
-  CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
-  slice_dims =
-      GetSliceDims(in_dims, axes, starts, ends, nullptr, nullptr);
-  out_dims = GetDecreasedDims(slice_dims, decrease_axis);
+  phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
+  slice_dims = phi::funcs::GetSliceDims(in_dims, axes, starts, ends,
+                                        nullptr, nullptr);
+  out_dims = phi::funcs::GetDecreasedDims(slice_dims, decrease_axis);
 
   // 2.2 Get output
   auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
diff --git a/paddle/fluid/operators/set_value_op.h b/paddle/fluid/operators/set_value_op.h
index 4696907f32e6d..580e62094fa0b 100644
--- a/paddle/fluid/operators/set_value_op.h
+++ b/paddle/fluid/operators/set_value_op.h
@@ -22,9 +22,11 @@
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/assign_value_op.h"
-#include "paddle/fluid/operators/slice_utils.h"
+#include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/utils.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/kernels/funcs/slice_utils.h"
 
 namespace paddle {
 namespace operators {
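For context on the helpers these hunks now call through phi::funcs: CheckAndUpdateSliceAttrs resolves Python-style negative indices and clamps out-of-range bounds before GetSliceDims measures the result. A rough standalone sketch of that normalization (illustrative only, assuming step == 1 and no infer_flags handling; this is not the Paddle API):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Normalize starts/ends the way a step-1 slice resolver would:
// add the dim extent to negative indices, then clamp into [0, dim].
void NormalizeSliceIndices(const std::vector<int64_t>& dims,
                           const std::vector<int64_t>& axes,
                           std::vector<int64_t>* starts,
                           std::vector<int64_t>* ends) {
  for (size_t i = 0; i < axes.size(); ++i) {
    int64_t dim = dims[axes[i]];
    int64_t start = (*starts)[i] < 0 ? (*starts)[i] + dim : (*starts)[i];
    int64_t end = (*ends)[i] < 0 ? (*ends)[i] + dim : (*ends)[i];
    start = std::min(std::max(start, int64_t{0}), dim);
    end = std::min(std::max(end, int64_t{0}), dim);
    (*starts)[i] = start;
    (*ends)[i] = std::max(end, start);  // an empty slice collapses to start
  }
}

int main() {
  std::vector<int64_t> dims = {4, 6}, axes = {1};
  std::vector<int64_t> starts = {-4}, ends = {100};
  NormalizeSliceIndices(dims, axes, &starts, &ends);
  assert(starts[0] == 2 && ends[0] == 6);  // x[:, -4:100] -> x[:, 2:6]
  return 0;
}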
diff --git a/paddle/fluid/operators/set_value_op_npu.cc b/paddle/fluid/operators/set_value_op_npu.cc
index 46d64333b608b..31ed820ebe07e 100644
--- a/paddle/fluid/operators/set_value_op_npu.cc
+++ b/paddle/fluid/operators/set_value_op_npu.cc
@@ -15,6 +15,8 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/set_value_op.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
+#include "paddle/phi/kernels/funcs/slice_utils.h"
+
 
 namespace paddle {
 namespace operators {
@@ -51,9 +53,11 @@ class SetValueNPUKernel : public framework::OpKernel<T> {
     }
 
     auto in_dims = in->dims();
-    CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, &steps);
-    auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, &steps);
-    auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
+    phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, &steps);
+    auto slice_dims =
+        phi::funcs::GetSliceDims(in_dims, axes, starts, ends, &steps);
+    auto decrease_slice_dims =
+        phi::funcs::GetDecreasedDims(slice_dims, decrease_axes);
 
     auto slice_dims_for_assign = decrease_slice_dims;
     if (!none_axes.empty()) {
diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc
index 689f93593fef4..c6432d00e9de1 100644
--- a/paddle/fluid/operators/slice_op.cc
+++ b/paddle/fluid/operators/slice_op.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include <vector>
+#include "paddle/phi/kernels/funcs/slice_utils.h"
 
 namespace paddle {
 namespace operators {
@@ -101,15 +102,17 @@ class SliceOp : public framework::OperatorWithKernel {
             "The size of ends must be equal to the size of axes."));
       }
 
-      CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, nullptr,
-                               &infer_flags);
+      phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends,
+                                           nullptr, &infer_flags);
 
-      auto slice_dims =
-          GetSliceDims(in_dims, axes, starts, ends, nullptr, &infer_flags);
+      auto slice_dims = phi::funcs::GetSliceDims(in_dims, axes, starts, ends,
+                                                 nullptr, &infer_flags);
       if (ctx->IsRuntime()) {
-        out_dims = GetDecreasedDims(slice_dims, decrease_axis, &infer_flags);
+        out_dims = phi::funcs::GetDecreasedDims(slice_dims, decrease_axis,
+                                                &infer_flags);
       } else {
-        out_dims = GetDecreasedDims(slice_dims, decrease_axis, nullptr);
+        out_dims =
+            phi::funcs::GetDecreasedDims(slice_dims, decrease_axis, nullptr);
       }
 
       ctx->SetOutputDim("Out", out_dims);
diff --git a/paddle/fluid/operators/slice_op.h b/paddle/fluid/operators/slice_op.h
index bd07909aa91af..a9a98b46d5eb7 100644
--- a/paddle/fluid/operators/slice_op.h
+++ b/paddle/fluid/operators/slice_op.h
@@ -18,7 +18,6 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/operators/slice_utils.h"
 #include "paddle/fluid/operators/utils.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
@@ -81,38 +80,6 @@ template <typename DeviceContext, typename T>
 class SliceKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    const Variable* input_var = ctx.InputVar("Input");
-    bool is_tensor_array = input_var->IsType<LoDTensorArray>();
-    int rank = is_tensor_array ? 1 : ctx.Input<Tensor>("Input")->dims().size();
-
-    switch (rank) {
-      case 1:
-        SliceCompute<1>(ctx);
-        break;
-      case 2:
-        SliceCompute<2>(ctx);
-        break;
-      case 3:
-        SliceCompute<3>(ctx);
-        break;
-      case 4:
-        SliceCompute<4>(ctx);
-        break;
-      case 5:
-        SliceCompute<5>(ctx);
-        break;
-      case 6:
-        SliceCompute<6>(ctx);
-        break;
-      default:
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "The rank of input should be less than 7, but received %d.", rank));
-    }
-  }
-
- private:
-  template <size_t D>
-  void SliceCompute(const framework::ExecutionContext& ctx) const {
     const Variable* input_var = ctx.InputVar("Input");
     Variable* out_var = ctx.OutputVar("Out");
     bool input_is_array = input_var->IsType<LoDTensorArray>();
@@ -156,68 +123,6 @@ class SliceKernel : public framework::OpKernel<T> {
     if (input_is_array) {
       DealTensorArray(ctx, starts, ends, out_is_array);
       return;
-    } else {
-      auto in = ctx.Input<Tensor>("Input");
-      auto out = ctx.Output<Tensor>("Out");
-
-      auto in_dims = in->dims();
-      auto out_dims = out->dims();
-      auto slice_dims = out_dims;
-
-      // 2.1 Infer output dims
-      for (size_t i = 0; i < axes.size(); ++i) {
-        // when start == -1 && end == start+1
-        if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
-          auto ret =
-              std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
-          if (ret != decrease_axis.end()) {
-            ends[i] = in_dims[axes[i]];
-          }
-        }
-      }
-
-      CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
-      slice_dims =
-          GetSliceDims(in_dims, axes, starts, ends, nullptr, nullptr);
-      out_dims = GetDecreasedDims(slice_dims, decrease_axis);
-
-      // 2.2 Get output
-      auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
-      auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
-
-      for (size_t i = 0; i < D; ++i) {
-        offsets[i] = 0;
-        extents[i] = slice_dims[i];
-      }
-      for (size_t i = 0; i < axes.size(); ++i) {
-        offsets[axes[i]] = starts[i];
-      }
-
-      out->Resize(slice_dims);
-      out->mutable_data<T>(ctx.GetPlace());
-
-      auto in_t = framework::EigenTensor<T, D>::From(*in, in_dims);
-      auto out_t = framework::EigenTensor<T, D>::From(*out, slice_dims);
-      auto& eigen_place =
-          *ctx.template device_context<DeviceContext>().eigen_device();
-
-      if (in->numel() <= Eigen::NumTraits<int>::highest()) {
-        // similar to tf.slice:
-        // if element number less than INT_MAX, change the type of index to int
-        Eigen::DSizes<int, D> offsets_32bit, extents_32bit;
-        for (size_t i = 0; i < D; i++) {
-          offsets_32bit[i] = offsets[i];
-          extents_32bit[i] = extents[i];
-        }
-        EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
-            eigen_place, framework::To32BitIndex(out_t),
-            framework::To32BitIndex(in_t), offsets_32bit, extents_32bit);
-      } else {
-        EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
-            eigen_place, out_t, in_t, offsets, extents);
-      }
-
-      out->Resize(out_dims);
     }
   }
 };
@@ -226,38 +131,6 @@ template <typename DeviceContext, typename T>
 class SliceGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    const Variable* input_var = ctx.InputVar("Input");
-    bool is_array = input_var->IsType<LoDTensorArray>();
-    size_t rank = is_array ? 1 : ctx.Input<Tensor>("Input")->dims().size();
-
-    switch (rank) {
-      case 1:
-        SliceCompute<1>(ctx);
-        break;
-      case 2:
-        SliceCompute<2>(ctx);
-        break;
-      case 3:
-        SliceCompute<3>(ctx);
-        break;
-      case 4:
-        SliceCompute<4>(ctx);
-        break;
-      case 5:
-        SliceCompute<5>(ctx);
-        break;
-      case 6:
-        SliceCompute<6>(ctx);
-        break;
-      default:
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "The rank of input should be less than 7, but received %d.", rank));
-    }
-  }
-
- private:
-  template <size_t D>
-  void SliceCompute(const framework::ExecutionContext& ctx) const {
     auto axes = ctx.Attr<std::vector<int>>("axes");
     auto starts_int = ctx.Attr<std::vector<int>>("starts");
     auto ends_int = ctx.Attr<std::vector<int>>("ends");
@@ -323,226 +196,9 @@ class SliceGradKernel : public framework::OpKernel<T> {
       }
       return;
     }
-
-    auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_input = ctx.Output<Tensor>(framework::GradVarName("Input"));
-    d_input->mutable_data<T>(ctx.GetPlace());
-
-    auto out_dims = d_out->dims();
-    auto in_dims = d_input->dims();
-
-    auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
-    auto decrease_size = decrease_axis.size();
-    if (decrease_size > 0) {
-      if (decrease_size == static_cast<size_t>(in_dims.size())) {
-        // all dims decrease
-        std::vector<int64_t> origin_out_shape(decrease_size, 1);
-        out_dims = phi::make_ddim(std::vector<int64_t>(decrease_size, 1));
-      } else {
-        std::vector<int64_t> origin_out_shape(out_dims.size() + decrease_size, -1);
-        for (size_t i = 0; i < decrease_size; ++i) {
-          origin_out_shape[decrease_axis[i]] = 1;
-        }
-
-        int index = 0;
-        for (size_t i = 0; i < origin_out_shape.size(); ++i) {
-          if (origin_out_shape[i] == -1) {
-            origin_out_shape[i] = out_dims[index];
-            ++index;
-          }
-        }
-
-        out_dims = phi::make_ddim(origin_out_shape);
-      }
-    }
-
-    auto offsets = Eigen::array<int64_t, D>();
-    auto extents = Eigen::array<int64_t, D>();
-    for (size_t i = 0; i < D; ++i) {
-      offsets[i] = 0;
-      extents[i] = out_dims[i];
-    }
-
-    for (size_t i = 0; i < axes.size(); ++i) {
-      int axis = axes[i];
-      int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
-      start = std::max(start, static_cast<int64_t>(0));
-      offsets[axis] = start;
-    }
-
-    Eigen::array<std::pair<int64_t, int64_t>, D> paddings;
-    for (size_t i = 0; i < paddings.size(); ++i) {
-      paddings[i].first = offsets[i];
-      paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
-    }
-    EigenPaddingCompute(ctx, d_input, in_dims, d_out, out_dims, paddings);
-  }
-
-  template <size_t D>
-  void EigenPaddingCompute(
-      const framework::ExecutionContext& context, Tensor* d_input,
-      const DDim& in_dims, const Tensor* d_out, const DDim& out_dims,
-      const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) const {
-    if (D <= 3) {
-      // if dimension less than 3, cannot reduce dimension
-      LaunchEigenPadding(context, d_input, in_dims, d_out, out_dims, paddings);
-    } else {  // else we can reduce dimension
-      // count not-zero padding number, and record the dimension
-      int need_pad_num = 0, pad_dim = -1;
-      for (size_t i = 0; i < D; i++) {
-        if (paddings[i].first != 0 || paddings[i].second != 0) {
-          need_pad_num++;
-          pad_dim = i;
-        }
-      }
-
-      if (need_pad_num == 1) {
-        // only need padding one dimension, we can reduce dimension.
-        // only the padding dimension is available for us.
-        // How to reduce dimension(5 to 3 for example):
-        // before(D=5):
-        // in_dims:        [x1, x2, x3, x4, x5]
-        // padding.first:  [0,  0,  a,  0,  0]
-        // padding.second: [0,  0,  b,  0,  0]
-        //                     |            |
-        //                     V            V
-        // after(D=3):
-        // reshaped_in_dims:        [x1*x2, x3, x4*x5]
-        // reshaped_padding.first:  [0,     a,  0]
-        // reshaped_padding.second: [0,     b,  0]
-
-        if (pad_dim == D - 1) {
-          // only last dimension need padding,
-          // reshape the dimension of tensor in 2: [preceding, padding]
-          std::vector<int64_t> in_tore_shape(2, 1), out_tore_shape(2, 1);
-          Eigen::array<std::pair<int64_t, int64_t>, 2> reshaped_padding;
-
-          // first dimension is the accumulate of preceding dimension
-          for (int i = 0; i < pad_dim; i++) {
-            in_tore_shape[0] *= in_dims[i];
-            out_tore_shape[0] *= out_dims[i];
-          }
-          // second dimension is the padding dimension
-          in_tore_shape[1] = in_dims[pad_dim];
-          out_tore_shape[1] = out_dims[pad_dim];
-
-          // convert array from std::vector to DDim
-          DDim reshaped_in_dims = phi::make_ddim(in_tore_shape);
-          DDim reshaped_out_dims = phi::make_ddim(out_tore_shape);
-
-          // after reshape: the first dimension do not need padding,
-          // set padding[0] zero
-          reshaped_padding[0].first = reshaped_padding[0].second = 0;
-          // the second dimension is the previous padding dimension
-          reshaped_padding[1].first = paddings[pad_dim].first;
-          reshaped_padding[1].second = paddings[pad_dim].second;
-
-          LaunchEigenPadding(context, d_input, reshaped_in_dims, d_out,
-                             reshaped_out_dims, reshaped_padding);
-        } else if (pad_dim == 0) {
-          // only first dimension need padding,
-          // reshape the dimension of tensor in 2: [padding, succeeding]
-          // similar to (D - 1)
-          std::vector<int64_t> in_tore_shape(2, 1), out_tore_shape(2, 1);
-          Eigen::array<std::pair<int64_t, int64_t>, 2> reshaped_padding;
-
-          // first dimension is the padding dimension
-          in_tore_shape[0] = in_dims[pad_dim];
-          out_tore_shape[0] = out_dims[pad_dim];
-          // sencond dimension is the accumulate of succeeding dimension
-          for (size_t i = pad_dim + 1; i < D; i++) {
-            in_tore_shape[1] *= in_dims[i];
-            out_tore_shape[1] *= out_dims[i];
-          }
-
-          // convert array from std::vector to DDim
-          DDim reshaped_in_dims = phi::make_ddim(in_tore_shape);
-          DDim reshaped_out_dims = phi::make_ddim(out_tore_shape);
-
-          // after reshape:
-          // the first dimension is the previous padding dimension
-          reshaped_padding[0].first = paddings[pad_dim].first;
-          reshaped_padding[0].second = paddings[pad_dim].second;
-          // the second dimension do not need padding, set padding[1] zero
-          reshaped_padding[1].first = reshaped_padding[1].second = 0;
-
-          LaunchEigenPadding(context, d_input, reshaped_in_dims, d_out,
-                             reshaped_out_dims, reshaped_padding);
-        } else {
-          // other dimension need padding
-          // reshape the dimension of tensor in 3:
-          // [preceding, padding, succeeding]
-          std::vector<int64_t> in_tore_shape(3, 1), out_tore_shape(3, 1);
-          Eigen::array<std::pair<int64_t, int64_t>, 3> reshaped_padding;
-
-          // first dimension is the accumulate of preceding dimension
-          for (int i = 0; i < pad_dim; i++) {
-            in_tore_shape[0] *= in_dims[i];
-            out_tore_shape[0] *= out_dims[i];
-          }
-          // second dimension is the padding dimension
-          in_tore_shape[1] = in_dims[pad_dim];
-          out_tore_shape[1] = out_dims[pad_dim];
-          // third dimension is the accumulate of succeeding dimension
-          for (size_t i = pad_dim + 1; i < D; i++) {
-            in_tore_shape[2] *= in_dims[i];
-            out_tore_shape[2] *= out_dims[i];
-          }
-
-          // convert array from std::vector to DDim
-          DDim reshaped_in_dims = phi::make_ddim(in_tore_shape);
-          DDim reshaped_out_dims = phi::make_ddim(out_tore_shape);
-
-          // after reshape:
-          // the first dimension do not need padding, set padding[0] zero
-          reshaped_padding[0].first = reshaped_padding[2].second = 0;
-          // the second dimension is the previous padding dimension
-          reshaped_padding[1].first = paddings[pad_dim].first;
-          reshaped_padding[1].second = paddings[pad_dim].second;
-          // the third dimension do not need padding, set padding[2] zero
-          reshaped_padding[2].first = reshaped_padding[2].second = 0;
-
-          LaunchEigenPadding(context, d_input, reshaped_in_dims, d_out,
-                             reshaped_out_dims, reshaped_padding);
-        }
-      } else {
-        // need padding at many dimension, cannot reduce dimension
-        LaunchEigenPadding(context, d_input, in_dims, d_out, out_dims,
-                           paddings);
-      }
-    }
-  }
 
-  template <size_t D>
-  void LaunchEigenPadding(
-      const framework::ExecutionContext& context, Tensor* d_input,
-      const DDim& in_dims, const Tensor* d_out, const DDim& out_dims,
-      const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) const {
-    auto& place =
-        *context.template device_context<DeviceContext>().eigen_device();
-    auto d_in_t =
-        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
-            *d_input, in_dims);
-    auto d_out_t =
-        framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
-            *d_out, out_dims);
-
-    if (d_input->numel() <= Eigen::NumTraits<int>::highest()) {
-      // similar to tf.pad:
-      // if element number less than INT_MAX, change the type of index to int
-      Eigen::array<std::pair<int, int>, D> paddings_32bit;
-      for (size_t i = 0; i < D; i++) {
-        paddings_32bit[i] =
-            std::make_pair(paddings[i].first, paddings[i].second);
-      }
-      EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
-          place, framework::To32BitIndex(d_in_t),
-          framework::To32BitIndex(d_out_t), paddings_32bit, static_cast<T>(0));
-    } else {
-      EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
-          place, d_in_t, d_out_t, paddings, static_cast<T>(0));
-    }
-  }
+
+ private:
 };
 
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/slice_op_npu.cc b/paddle/fluid/operators/slice_op_npu.cc
index 6bc2ae3663894..0d0d9ab19df30 100644
--- a/paddle/fluid/operators/slice_op_npu.cc
+++ b/paddle/fluid/operators/slice_op_npu.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/slice_op.h"
+
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
+#include "paddle/phi/kernels/funcs/slice_utils.h"
 
 namespace paddle {
 namespace operators {
@@ -109,10 +111,10 @@ class SliceNPUKernel : public framework::OpKernel<T> {
       }
     }
 
-    CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
-    slice_dims =
-        GetSliceDims(in_dims, axes, starts, ends, nullptr, nullptr);
-    out_dims = GetDecreasedDims(slice_dims, decrease_axis);
+    phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
+    slice_dims = phi::funcs::GetSliceDims(in_dims, axes, starts, ends,
+                                          nullptr, nullptr);
+    out_dims = phi::funcs::GetDecreasedDims(slice_dims, decrease_axis);
 
     out->Resize(out_dims);
   }
diff --git a/paddle/phi/kernels/cpu/slice_grad_kernel.cc b/paddle/phi/kernels/cpu/slice_grad_kernel.cc
new file mode 100644
index 0000000000000..5c2cb3ea80e87
--- /dev/null
+++ b/paddle/phi/kernels/cpu/slice_grad_kernel.cc
@@ -0,0 +1,32 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/slice_grad_kernel.h"
+#include "paddle/phi/kernels/impl/slice_grad_kernel_impl.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(slice_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SliceGradRawKernel,
+                   bool,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/cpu/slice_kernel.cc b/paddle/phi/kernels/cpu/slice_kernel.cc
new file mode 100644
index 0000000000000..736540609dd72
--- /dev/null
+++ b/paddle/phi/kernels/cpu/slice_kernel.cc
@@ -0,0 +1,32 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/slice_kernel.h"
+#include "paddle/phi/kernels/impl/slice_kernel_impl.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(slice,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SliceRawKernel,
+                   bool,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/funcs/eigen/eigen_function.h b/paddle/phi/kernels/funcs/eigen/eigen_function.h
index b763de5621a37..b971b4f95ef57 100644
--- a/paddle/phi/kernels/funcs/eigen/eigen_function.h
+++ b/paddle/phi/kernels/funcs/eigen/eigen_function.h
@@ -163,11 +163,11 @@ struct EigenPad {
                    const InType& in,
                    const Array& padding,
                    const T value);
-  static void Eval(const EigenDevice& dev,
-                   OutType32BitIndex out,
-                   const InType32BitIndex& in,
-                   const Array32Bit& padding,
-                   const T value);
+  static void Eval32(const EigenDevice& dev,
+                     OutType32BitIndex out,
+                     const InType32BitIndex& in,
+                     const Array32Bit& padding,
+                     const T value);
 };
 
 template <typename EigenDevice, typename T>
diff --git a/paddle/phi/kernels/funcs/eigen/pad.cc b/paddle/phi/kernels/funcs/eigen/pad.cc
index 7da72cab7690c..c457199b0a93c 100644
--- a/paddle/phi/kernels/funcs/eigen/pad.cc
+++ b/paddle/phi/kernels/funcs/eigen/pad.cc
@@ -41,11 +41,11 @@ struct EigenPad {
     out.device(dev) = in.pad(padding, value);
   }
 
-  static void Eval(const Eigen::DefaultDevice& dev,
-                   OutType32BitIndex out,
-                   const InType32BitIndex& in,
-                   const Array32Bit& padding,
-                   const T value) {
+  static void Eval32(const Eigen::DefaultDevice& dev,
+                     OutType32BitIndex out,
+                     const InType32BitIndex& in,
+                     const Array32Bit& padding,
+                     const T value) {
     out.device(dev) = in.pad(padding, value);
   }
 };
@@ -56,7 +56,8 @@ struct EigenPad {
   template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 3>; \
   template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 4>; \
   template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 5>; \
-  template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 6>
+  template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 6>;
+
 INSTANTIATION(EigenPad, bool);
 INSTANTIATION(EigenPad, int);
 INSTANTIATION(EigenPad, int64_t);
diff --git a/paddle/phi/kernels/funcs/eigen/pad.cu b/paddle/phi/kernels/funcs/eigen/pad.cu
index 2978078e67339..7d8c2580d9621 100644
--- a/paddle/phi/kernels/funcs/eigen/pad.cu
+++ b/paddle/phi/kernels/funcs/eigen/pad.cu
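The rename from Eval to Eval32 in the hunks above gives the 32-bit-index overload of EigenPad its own name, so the kernels select it explicitly rather than through overload resolution. Which path runs is decided by the numel() <= Eigen::NumTraits<int>::highest() guard seen elsewhere in this patch; a minimal standalone sketch of the underlying idea (illustrative only, not Paddle code):

#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Offset arithmetic parameterized on the index type; 32-bit indices are
// cheaper on GPU, and are safe whenever the element count fits in int.
template <typename Index>
Index FlattenOffset(const std::vector<Index>& strides,
                    const std::vector<Index>& idx) {
  Index off = 0;
  for (size_t i = 0; i < strides.size(); ++i) off += strides[i] * idx[i];
  return off;
}

int main() {
  int64_t numel = int64_t{1} << 20;
  if (numel <= std::numeric_limits<int>::max()) {
    // same decision the kernels make before taking the Eval32 path
    std::vector<int> strides = {1024, 1}, idx = {7, 42};
    std::cout << FlattenOffset(strides, idx) << "\n";  // prints 7210
  }
  return 0;
}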
@@ -42,11 +42,11 @@ struct EigenPad {
     out.device(dev) = in.pad(padding, value);
   }
 
-  static void Eval(const Eigen::GpuDevice& dev,
-                   OutType32BitIndex out,
-                   const InType32BitIndex& in,
-                   const Array32Bit& padding,
-                   const T value) {
+  static void Eval32(const Eigen::GpuDevice& dev,
+                     OutType32BitIndex out,
+                     const InType32BitIndex& in,
+                     const Array32Bit& padding,
+                     const T value) {
     out.device(dev) = in.pad(padding, value);
   }
 };
diff --git a/paddle/fluid/operators/slice_utils.h b/paddle/phi/kernels/funcs/slice_utils.h
similarity index 72%
rename from paddle/fluid/operators/slice_utils.h
rename to paddle/phi/kernels/funcs/slice_utils.h
index ed26b52675c9c..0c956248fd9ef 100644
--- a/paddle/fluid/operators/slice_utils.h
+++ b/paddle/phi/kernels/funcs/slice_utils.h
@@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include <paddle/fluid/framework/operator.h>
+#include <paddle/phi/core/ddim.h>
 #include <string>
 #include <vector>
 
-namespace paddle {
-namespace operators {
-using Tensor = framework::Tensor;
+namespace phi {
+
+namespace funcs {
 
 template <typename T = int64_t>
-inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
+inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
                                      const std::vector<T>& axes,
                                      std::vector<T>* starts,
                                      std::vector<T>* ends,
@@ -31,11 +31,14 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
   for (size_t i = 0; i < axes.size(); ++i) {
     T axis = axes[i];
     PADDLE_ENFORCE_LT(
-        axis, in_dims.size(),
-        platform::errors::InvalidArgument(
+        axis,
+        in_dims.size(),
+        phi::errors::InvalidArgument(
             "The axis value should be less than the rank of input, "
             "but received axes[%d] = %d, rank of input is %d.",
-            i, axis, in_dims.size()));
+            i,
+            axis,
+            in_dims.size()));
 
     if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
       continue;
@@ -46,8 +49,10 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
     if (dim_value > 0) {
       T step = steps == nullptr ? 1 : (*steps)[i];
       PADDLE_ENFORCE_NE(
-          step, 0, platform::errors::InvalidArgument(
-                       "Step should not be 0, but received step = %d.", step));
+          step,
+          0,
+          phi::errors::InvalidArgument(
+              "Step should not be 0, but received step = %d.", step));
 
       T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
       start = std::max(start, static_cast<T>(0));
 
@@ -60,11 +65,13 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
         start = std::min(start, dim_value);
         end = std::max(end, static_cast<T>(0));
         PADDLE_ENFORCE_GE(
-            end, start,
-            platform::errors::InvalidArgument(
+            end,
+            start,
+            phi::errors::InvalidArgument(
                 "When step > 0, end should be greater than start, but "
                 "received end = %d, start = %d.",
-                end, start));
+                end,
+                start));
       } else {
         // NOTE(liym27): When step < 0, start should less and equal to
         // dim_value-1
@@ -72,11 +79,13 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
         start = std::min(start, dim_value - 1);
         end = std::max(end, static_cast<T>(-1));
         PADDLE_ENFORCE_GE(
-            start, end,
-            platform::errors::InvalidArgument(
+            start,
+            end,
+            phi::errors::InvalidArgument(
                 "When step < 0, start should be greater than end, but "
                 "received start = %d, end = %d.",
-                start, end));
+                start,
+                end));
       }
 
       (*starts)[i] = start;
@@ -89,13 +98,13 @@ inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
 }
 
 template <typename T = int64_t>
-inline framework::DDim GetSliceDims(const framework::DDim in_dims,
-                                    const std::vector<T>& axes,
-                                    const std::vector<T>& starts,
-                                    const std::vector<T>& ends,
-                                    std::vector<T>* steps = nullptr,
-                                    std::vector<T>* infer_flags = nullptr) {
-  framework::DDim slice_dims(in_dims);
+inline phi::DDim GetSliceDims(const phi::DDim in_dims,
+                              const std::vector<T>& axes,
+                              const std::vector<T>& starts,
+                              const std::vector<T>& ends,
+                              std::vector<T>* steps = nullptr,
+                              std::vector<T>* infer_flags = nullptr) {
+  phi::DDim slice_dims(in_dims);
 
   for (size_t i = 0; i < axes.size(); ++i) {
     T axis = axes[i];
@@ -118,18 +127,19 @@ inline phi::DDim GetSliceDims(const phi::DDim in_dims,
 }
 
 template <typename T = int64_t>
-inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
-                                        const std::vector<T>& decrease_axes,
-                                        std::vector<T>* infer_flags = nullptr) {
-  framework::DDim decreased_dims(slice_dims);
+inline DDim GetDecreasedDims(const DDim slice_dims,
+                             const std::vector<T>& decrease_axes,
+                             std::vector<T>* infer_flags = nullptr) {
+  DDim decreased_dims(slice_dims);
   std::vector<uint8_t> decrease_flag(slice_dims.size(), 0);
   if (decrease_axes.size() > 0) {
     for (size_t i = 0; i < decrease_axes.size(); ++i) {
       T axis = decrease_axes[i];
       decrease_flag[axis] = 1;
       if (infer_flags && (*infer_flags)[i] != -1) {
-        PADDLE_ENFORCE_EQ(decreased_dims[axis], 1,
-                          platform::errors::InvalidArgument(
+        PADDLE_ENFORCE_EQ(decreased_dims[axis],
+                          1,
+                          phi::errors::InvalidArgument(
                               "Decrease dim should be 1, but now received %d",
                               decreased_dims[axis]));
       }
@@ -153,5 +163,5 @@ inline DDim GetDecreasedDims(const DDim slice_dims,
   return decreased_dims;
 }
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
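GetDecreasedDims above implements the decrease_axis contract: each listed axis must have extent 1 after slicing and is then dropped from the shape. A hedged standalone sketch of that behavior (std::vector standing in for DDim, assert standing in for PADDLE_ENFORCE_EQ):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> DecreaseDims(const std::vector<int64_t>& slice_dims,
                                  const std::vector<int64_t>& decrease_axes) {
  std::vector<bool> drop(slice_dims.size(), false);
  for (int64_t axis : decrease_axes) {
    assert(slice_dims[axis] == 1);  // decreased axes must have extent 1
    drop[axis] = true;
  }
  std::vector<int64_t> out;
  for (size_t i = 0; i < slice_dims.size(); ++i) {
    if (!drop[i]) out.push_back(slice_dims[i]);
  }
  if (out.empty()) out.push_back(1);  // fully decreased shape becomes {1}
  return out;
}

int main() {
  // x[1:2, :, 3:4] with decrease_axes {0, 2}: shape {1, 5, 1} -> {5}
  assert(DecreaseDims({1, 5, 1}, {0, 2}) == std::vector<int64_t>({5}));
  return 0;
}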
diff --git a/paddle/phi/kernels/gpu/slice_grad_kernel.cu.cc b/paddle/phi/kernels/gpu/slice_grad_kernel.cu.cc
new file mode 100644
index 0000000000000..2769f5cc65d71
--- /dev/null
+++ b/paddle/phi/kernels/gpu/slice_grad_kernel.cu.cc
@@ -0,0 +1,33 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/slice_grad_kernel.h"
+#include "paddle/phi/kernels/impl/slice_grad_kernel_impl.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(slice_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SliceGradRawKernel,
+                   bool,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/slice_kernel.cu.cc b/paddle/phi/kernels/gpu/slice_kernel.cu.cc
new file mode 100644
index 0000000000000..0fa61962c9eb0
--- /dev/null
+++ b/paddle/phi/kernels/gpu/slice_kernel.cu.cc
@@ -0,0 +1,33 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/slice_kernel.h"
+#include "paddle/phi/kernels/impl/slice_kernel_impl.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(slice,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::SliceRawKernel,
+                   bool,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>,
+                   phi::dtype::bfloat16,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/impl/set_value_kernel_impl.h b/paddle/phi/kernels/impl/set_value_kernel_impl.h
index 99db559f3b816..cbe94efb43908 100644
--- a/paddle/phi/kernels/impl/set_value_kernel_impl.h
+++ b/paddle/phi/kernels/impl/set_value_kernel_impl.h
@@ -24,8 +24,7 @@
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
-
-#include "paddle/fluid/operators/slice_utils.h"
+#include "paddle/phi/kernels/funcs/slice_utils.h"
 
 namespace phi {
 
@@ -85,12 +84,12 @@ void SetValueImpl(const Context& dev_ctx,
   std::vector<int64_t> starts_local = starts.GetData();
   std::vector<int64_t> ends_local = ends.GetData();
   std::vector<int64_t> steps_local = steps.GetData();
-  paddle::operators::CheckAndUpdateSliceAttrs(
+  phi::funcs::CheckAndUpdateSliceAttrs(
       in_dims, axes, &starts_local, &ends_local, &steps_local);
-  auto slice_dims = paddle::operators::GetSliceDims(
+  auto slice_dims = phi::funcs::GetSliceDims(
       in_dims, axes, starts_local, ends_local, &steps_local);
   auto decrease_slice_dims =
-      paddle::operators::GetDecreasedDims(slice_dims, decrease_axes);
+      phi::funcs::GetDecreasedDims(slice_dims, decrease_axes);
 
   auto slice_dims_for_assign = decrease_slice_dims;
   if (!none_axes.empty()) {
diff --git a/paddle/phi/kernels/impl/slice_grad_kernel_impl.h b/paddle/phi/kernels/impl/slice_grad_kernel_impl.h
new file mode 100644
index 0000000000000..1dbb5bd142c52
--- /dev/null
+++ b/paddle/phi/kernels/impl/slice_grad_kernel_impl.h
@@ -0,0 +1,354 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/slice_grad_kernel.h"
+
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/slice_utils.h"
+
+namespace phi {
+
+template <typename T, typename Context, size_t D>
+void LaunchEigenPadding(
+    const Context& context,
+    DenseTensor* d_input,
+    const DDim& in_dims,
+    const DenseTensor* d_out,
+    const DDim& out_dims,
+    const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) {
+  auto& place = *context.eigen_device();
+  auto d_in_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
+      *d_input, in_dims);
+  auto d_out_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
+      *d_out, out_dims);
+
+  if (d_input->numel() <= Eigen::NumTraits<int>::highest()) {
+    // similar to tf.pad:
+    // if element number less than INT_MAX, change the type of index to int
+    Eigen::array<std::pair<int, int>, D> paddings_32bit;
+    for (size_t i = 0; i < D; i++) {
+      paddings_32bit[i] = std::make_pair(paddings[i].first, paddings[i].second);
+    }
+    funcs::EigenPad<std::decay_t<decltype(place)>, T, D>::Eval32(
+        place,
+        To32BitIndex(d_in_t),
+        To32BitIndex(d_out_t),
+        paddings_32bit,
+        static_cast<T>(0));
+  } else {
+    funcs::EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
+        place, d_in_t, d_out_t, paddings, static_cast<T>(0));
+  }
+}
+
+template <typename T, typename Context, size_t D>
+void EigenPaddingCompute(
+    const Context& context,
+    DenseTensor* d_input,
+    const DDim& in_dims,
+    const DenseTensor* d_out,
+    const DDim& out_dims,
+    const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) {
+  if (D <= 3) {
+    // if dimension less than 3, cannot reduce dimension
+    LaunchEigenPadding<T, Context, D>(
+        context, d_input, in_dims, d_out, out_dims, paddings);
+  } else {  // else we can reduce dimension
+    // count not-zero padding number, and record the dimension
+    int need_pad_num = 0, pad_dim = -1;
+    for (size_t i = 0; i < D; i++) {
+      if (paddings[i].first != 0 || paddings[i].second != 0) {
+        need_pad_num++;
+        pad_dim = i;
+      }
+    }
+
+    if (need_pad_num == 1) {
+      // only need padding one dimension, we can reduce dimension.
+      // only the padding dimension is available for us.
+      // How to reduce dimension(5 to 3 for example):
+      // before(D=5):
+      // in_dims:        [x1, x2, x3, x4, x5]
+      // padding.first:  [0,  0,  a,  0,  0]
+      // padding.second: [0,  0,  b,  0,  0]
+      //                     |            |
+      //                     V            V
+      // after(D=3):
+      // reshaped_in_dims:        [x1*x2, x3, x4*x5]
+      // reshaped_padding.first:  [0,     a,  0]
+      // reshaped_padding.second: [0,     b,  0]
+
+      if (pad_dim == D - 1) {
+        // only last dimension need padding,
+        // reshape the dimension of tensor in 2: [preceding, padding]
+        std::vector<int64_t> in_tore_shape(2, 1), out_tore_shape(2, 1);
+        Eigen::array<std::pair<int64_t, int64_t>, 2> reshaped_padding;
+
+        // first dimension is the accumulate of preceding dimension
+        for (int i = 0; i < pad_dim; i++) {
+          in_tore_shape[0] *= in_dims[i];
+          out_tore_shape[0] *= out_dims[i];
+        }
+        // second dimension is the padding dimension
+        in_tore_shape[1] = in_dims[pad_dim];
+        out_tore_shape[1] = out_dims[pad_dim];
+
+        // convert array from std::vector to DDim
+        DDim reshaped_in_dims = make_ddim(in_tore_shape);
+        DDim reshaped_out_dims = make_ddim(out_tore_shape);
+
+        // after reshape: the first dimension does not need padding,
+        // set padding[0] zero
+        reshaped_padding[0].first = reshaped_padding[0].second = 0;
+        // the second dimension is the previous padding dimension
+        reshaped_padding[1].first = paddings[pad_dim].first;
+        reshaped_padding[1].second = paddings[pad_dim].second;
+
+        LaunchEigenPadding<T, Context, 2>(context,
+                                          d_input,
+                                          reshaped_in_dims,
+                                          d_out,
+                                          reshaped_out_dims,
+                                          reshaped_padding);
+      } else if (pad_dim == 0) {
+        // only first dimension need padding,
+        // reshape the dimension of tensor in 2: [padding, succeeding]
+        // similar to (D - 1)
+        std::vector<int64_t> in_tore_shape(2, 1), out_tore_shape(2, 1);
+        Eigen::array<std::pair<int64_t, int64_t>, 2> reshaped_padding;
+
+        // first dimension is the padding dimension
+        in_tore_shape[0] = in_dims[pad_dim];
+        out_tore_shape[0] = out_dims[pad_dim];
+        // second dimension is the accumulate of succeeding dimension
+        for (size_t i = pad_dim + 1; i < D; i++) {
+          in_tore_shape[1] *= in_dims[i];
+          out_tore_shape[1] *= out_dims[i];
+        }
+
+        // convert array from std::vector to DDim
+        DDim reshaped_in_dims = make_ddim(in_tore_shape);
+        DDim reshaped_out_dims = make_ddim(out_tore_shape);
+
+        // after reshape:
+        // the first dimension is the previous padding dimension
+        reshaped_padding[0].first = paddings[pad_dim].first;
+        reshaped_padding[0].second = paddings[pad_dim].second;
+        // the second dimension does not need padding, set padding[1] zero
+        reshaped_padding[1].first = reshaped_padding[1].second = 0;
+
+        LaunchEigenPadding<T, Context, 2>(context,
+                                          d_input,
+                                          reshaped_in_dims,
+                                          d_out,
+                                          reshaped_out_dims,
+                                          reshaped_padding);
+      } else {
+        // other dimension need padding
+        // reshape the dimension of tensor in 3:
+        // [preceding, padding, succeeding]
+        std::vector<int64_t> in_tore_shape(3, 1), out_tore_shape(3, 1);
+        Eigen::array<std::pair<int64_t, int64_t>, 3> reshaped_padding;
+
+        // first dimension is the accumulate of preceding dimension
+        for (int i = 0; i < pad_dim; i++) {
+          in_tore_shape[0] *= in_dims[i];
+          out_tore_shape[0] *= out_dims[i];
+        }
+        // second dimension is the padding dimension
+        in_tore_shape[1] = in_dims[pad_dim];
+        out_tore_shape[1] = out_dims[pad_dim];
+        // third dimension is the accumulate of succeeding dimension
+        for (size_t i = pad_dim + 1; i < D; i++) {
+          in_tore_shape[2] *= in_dims[i];
+          out_tore_shape[2] *= out_dims[i];
+        }
+
+        // convert array from std::vector to DDim
+        DDim reshaped_in_dims = make_ddim(in_tore_shape);
+        DDim reshaped_out_dims = make_ddim(out_tore_shape);
+
+        // after reshape:
+        // the first dimension does not need padding, set padding[0] zero
+        reshaped_padding[0].first = reshaped_padding[0].second = 0;
+        // the second dimension is the previous padding dimension
+        reshaped_padding[1].first = paddings[pad_dim].first;
+        reshaped_padding[1].second = paddings[pad_dim].second;
+        // the third dimension does not need padding, set padding[2] zero
+        reshaped_padding[2].first = reshaped_padding[2].second = 0;
+
+        LaunchEigenPadding<T, Context, 3>(context,
+                                          d_input,
+                                          reshaped_in_dims,
+                                          d_out,
+                                          reshaped_out_dims,
+                                          reshaped_padding);
+      }
+    } else {
+      // need padding at many dimension, cannot reduce dimension
+      LaunchEigenPadding<T, Context, D>(
+          context, d_input, in_dims, d_out, out_dims, paddings);
+    }
+  }
+}
+
+template <typename T, typename Context, size_t D>
+void SliceGradCompute(const Context& ctx,
+                      const DenseTensor& out_grad,
+                      const std::vector<int64_t>& axes,
+                      const std::vector<int64_t>& starts,
+                      const std::vector<int64_t>& ends,
+                      const std::vector<int64_t>& infer_flags,
+                      const std::vector<int64_t>& decrease_axis,
+                      DenseTensor* input_grad) {
+  auto* d_out = &out_grad;
+  auto* d_input = input_grad;
+  ctx.template Alloc<T>(d_input);
+
+  auto out_dims = d_out->dims();
+  auto in_dims = d_input->dims();
+
+  auto decrease_size = decrease_axis.size();
+  if (decrease_size > 0) {
+    if (decrease_size == static_cast<size_t>(in_dims.size())) {
+      // all dims decrease
+      std::vector<int64_t> origin_out_shape(decrease_size, 1);
+      out_dims = make_ddim(std::vector<int64_t>(decrease_size, 1));
+    } else {
+      std::vector<int64_t> origin_out_shape(out_dims.size() + decrease_size, -1);
+      for (size_t i = 0; i < decrease_size; ++i) {
+        origin_out_shape[decrease_axis[i]] = 1;
+      }
+
+      int index = 0;
+      for (size_t i = 0; i < origin_out_shape.size(); ++i) {
+        if (origin_out_shape[i] == -1) {
+          origin_out_shape[i] = out_dims[index];
+          ++index;
+        }
+      }
+
+      out_dims = make_ddim(origin_out_shape);
+    }
+  }
+
+  auto offsets = Eigen::array<int64_t, D>();
+  auto extents = Eigen::array<int64_t, D>();
+  for (size_t i = 0; i < D; ++i) {
+    offsets[i] = 0;
+    extents[i] = out_dims[i];
+  }
+
+  for (size_t i = 0; i < axes.size(); ++i) {
+    int axis = axes[i];
+    int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
+    start = std::max(start, static_cast<int64_t>(0));
+    offsets[axis] = start;
+  }
+
+  Eigen::array<std::pair<int64_t, int64_t>, D> paddings;
+  for (size_t i = 0; i < paddings.size(); ++i) {
+    paddings[i].first = offsets[i];
+    paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
+  }
+  EigenPaddingCompute<T, Context, D>(
+      ctx, d_input, in_dims, d_out, out_dims, paddings);
+}
+
+template <typename T, typename Context>
+void SliceGradRawKernel(const Context& ctx,
+                        const DenseTensor& input,
+                        const DenseTensor& out_grad,
+                        const std::vector<int64_t>& axes,
+                        const ScalarArray& starts_arr,
+                        const ScalarArray& ends_arr,
+                        const std::vector<int64_t>& infer_flags,
+                        const std::vector<int64_t>& decrease_axis,
+                        DenseTensor* input_grad) {
+  size_t rank = input.dims().size();
+
+  auto& starts = starts_arr.GetData();
+  auto& ends = ends_arr.GetData();
+
+  switch (rank) {
+    case 1:
+      SliceGradCompute<T, Context, 1>(ctx,
+                                      out_grad,
+                                      axes,
+                                      starts,
+                                      ends,
+                                      infer_flags,
+                                      decrease_axis,
+                                      input_grad);
+      break;
+    case 2:
+      SliceGradCompute<T, Context, 2>(ctx,
+                                      out_grad,
+                                      axes,
+                                      starts,
+                                      ends,
+                                      infer_flags,
+                                      decrease_axis,
+                                      input_grad);
+      break;
+    case 3:
+      SliceGradCompute<T, Context, 3>(ctx,
+                                      out_grad,
+                                      axes,
+                                      starts,
+                                      ends,
+                                      infer_flags,
+                                      decrease_axis,
+                                      input_grad);
+      break;
+    case 4:
+      SliceGradCompute<T, Context, 4>(ctx,
+                                      out_grad,
+                                      axes,
+                                      starts,
+                                      ends,
+                                      infer_flags,
+                                      decrease_axis,
+                                      input_grad);
+      break;
+    case 5:
+      SliceGradCompute<T, Context, 5>(ctx,
+                                      out_grad,
+                                      axes,
+                                      starts,
+                                      ends,
+                                      infer_flags,
+                                      decrease_axis,
+                                      input_grad);
+      break;
+    case 6:
+      SliceGradCompute<T, Context, 6>(ctx,
+                                      out_grad,
+                                      axes,
+                                      starts,
+                                      ends,
+                                      infer_flags,
+                                      decrease_axis,
+                                      input_grad);
+      break;
+    default:
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "The rank of input should be less than 7, but received %d.", rank));
+  }
+}
+
+}  // namespace phi
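The grad implementation above rests on the identity that the gradient of a slice is the output gradient zero-padded back to the input shape: paddings[i].first is the slice offset and paddings[i].second is whatever remains to the right; the D > 3 branches only reshape so that Eigen sees at most three dimensions. A toy 1-D version of the identity, for intuition only:

#include <cassert>
#include <cstdint>
#include <vector>

// d(x[offset : offset + len]) scattered back into a zeroed d_x.
std::vector<float> PadGrad(const std::vector<float>& d_out,
                           int64_t in_len,
                           int64_t offset) {
  std::vector<float> d_in(in_len, 0.0f);  // zero-fill, then copy the window
  for (size_t i = 0; i < d_out.size(); ++i) d_in[offset + i] = d_out[i];
  return d_in;
}

int main() {
  // out = x[2:5] on a length-8 axis: pad 2 zeros left, 1 zero right
  auto d_in = PadGrad({1.f, 2.f, 3.f}, /*in_len=*/8, /*offset=*/2);
  assert(d_in[1] == 0.f && d_in[2] == 1.f && d_in[4] == 3.f && d_in[7] == 0.f);
  return 0;
}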
diff --git a/paddle/phi/kernels/impl/slice_kernel_impl.h b/paddle/phi/kernels/impl/slice_kernel_impl.h
new file mode 100644
index 0000000000000..5c127358e8eee
--- /dev/null
+++ b/paddle/phi/kernels/impl/slice_kernel_impl.h
@@ -0,0 +1,154 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/slice_utils.h"
+
+namespace phi {
+
+template <typename T, typename Context, size_t D>
+void SliceCompute(const Context& ctx,
+                  const DenseTensor& input,
+                  const std::vector<int64_t>& axes,
+                  const std::vector<int64_t>& starts_t,
+                  const std::vector<int64_t>& ends_t,
+                  const std::vector<int64_t>& infer_flags,
+                  const std::vector<int64_t>& decrease_axis,
+                  DenseTensor* out) {
+  // Step 1: Get the accurate attribute value of starts and ends
+  std::vector<int64_t> starts = starts_t;
+  std::vector<int64_t> ends = ends_t;
+  PADDLE_ENFORCE_EQ(
+      starts.size(),
+      axes.size(),
+      phi::errors::InvalidArgument(
+          "The size of starts must be equal to the size of axes."));
+  PADDLE_ENFORCE_EQ(ends.size(),
+                    axes.size(),
+                    phi::errors::InvalidArgument(
+                        "The size of ends must be equal to the size of axes."));
+
+  // Step 2: Compute output
+  auto in = &input;
+
+  auto in_dims = in->dims();
+  auto out_dims = out->dims();
+  auto slice_dims = out_dims;
+
+  // 2.1 Infer output dims
+  for (size_t i = 0; i < axes.size(); ++i) {
+    // when start == -1 && end == start+1
+    if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
+      auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
+      if (ret != decrease_axis.end()) {
+        ends[i] = in_dims[axes[i]];
+      }
+    }
+  }
+
+  funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
+  slice_dims = funcs::GetSliceDims<int64_t>(
+      in_dims, axes, starts, ends, nullptr, nullptr);
+  out_dims = funcs::GetDecreasedDims(slice_dims, decrease_axis);
+
+  // 2.2 Get output
+  auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
+  auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
+
+  for (size_t i = 0; i < D; ++i) {
+    offsets[i] = 0;
+    extents[i] = slice_dims[i];
+  }
+  for (size_t i = 0; i < axes.size(); ++i) {
+    offsets[axes[i]] = starts[i];
+  }
+
+  out->Resize(slice_dims);
+  ctx.template Alloc<T>(out);
+
+  auto in_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
+      *in, in_dims);
+  auto out_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
+      *out, slice_dims);
+  auto& eigen_place = *ctx.eigen_device();
+
+  if (in->numel() <= Eigen::NumTraits<int>::highest()) {
+    // similar to tf.slice:
+    // if element number less than INT_MAX, change the type of index to int
+    Eigen::DSizes<int, D> offsets_32bit, extents_32bit;
+    for (size_t i = 0; i < D; i++) {
+      offsets_32bit[i] = offsets[i];
+      extents_32bit[i] = extents[i];
+    }
+    funcs::EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
+        eigen_place,
+        To32BitIndex(out_t),
+        To32BitIndex(in_t),
+        offsets_32bit,
+        extents_32bit);
+  } else {
+    funcs::EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
+        eigen_place, out_t, in_t, offsets, extents);
+  }
+
+  out->Resize(out_dims);
+}
+
+template <typename T, typename Context>
+void SliceRawKernel(const Context& ctx,
+                    const DenseTensor& input,
+                    const std::vector<int64_t>& axes,
+                    const ScalarArray& starts_arr,
+                    const ScalarArray& ends_arr,
+                    const std::vector<int64_t>& infer_flags,
+                    const std::vector<int64_t>& decrease_axis,
+                    DenseTensor* out) {
+  int rank = input.dims().size();
+
+  auto& starts = starts_arr.GetData();
+  auto& ends = ends_arr.GetData();
+
+  switch (rank) {
+    case 1:
+      SliceCompute<T, Context, 1>(
+          ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
+      break;
+    case 2:
+      SliceCompute<T, Context, 2>(
+          ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
+      break;
+    case 3:
+      SliceCompute<T, Context, 3>(
+          ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
+      break;
+    case 4:
+      SliceCompute<T, Context, 4>(
+          ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
+      break;
+    case 5:
+      SliceCompute<T, Context, 5>(
+          ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
+      break;
+    case 6:
+      SliceCompute<T, Context, 6>(
+          ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
+      break;
+    default:
+      PADDLE_THROW(phi::errors::InvalidArgument(
+          "The rank of input should be less than 7, but received %d.", rank));
+  }
+}
+
+}  // namespace phi
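The switch in SliceRawKernel above (and its grad counterpart) exists because Eigen encodes tensor rank in the C++ type, so a runtime rank has to be turned into a compile-time template argument. The same dispatch pattern in miniature (names illustrative):

#include <cstdio>

template <int D>
void RankKernel() {
  std::printf("running rank-%d kernel\n", D);
}

void Dispatch(int rank) {
  switch (rank) {
    case 1: RankKernel<1>(); break;
    case 2: RankKernel<2>(); break;
    case 3: RankKernel<3>(); break;
    default: std::printf("unsupported rank %d\n", rank);
  }
}

int main() {
  Dispatch(2);  // selects RankKernel<2> at compile time, by value at runtime
  return 0;
}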
diff --git a/paddle/phi/kernels/slice_grad_kernel.h b/paddle/phi/kernels/slice_grad_kernel.h
new file mode 100644
index 0000000000000..a7ee9ffde4eb0
--- /dev/null
+++ b/paddle/phi/kernels/slice_grad_kernel.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar_array.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SliceGradRawKernel(const Context& ctx,
+                        const DenseTensor& input,
+                        const DenseTensor& out_grad,
+                        const std::vector<int64_t>& axes,
+                        const ScalarArray& starts,
+                        const ScalarArray& ends,
+                        const std::vector<int64_t>& infer_flags,
+                        const std::vector<int64_t>& decrease_axis,
+                        DenseTensor* input_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/slice_kernel.h b/paddle/phi/kernels/slice_kernel.h
new file mode 100644
index 0000000000000..ff27824b9e676
--- /dev/null
+++ b/paddle/phi/kernels/slice_kernel.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar_array.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SliceRawKernel(const Context& ctx,
+                    const DenseTensor& input,
+                    const std::vector<int64_t>& axes,
+                    const ScalarArray& starts,
+                    const ScalarArray& ends,
+                    const std::vector<int64_t>& infer_flags,
+                    const std::vector<int64_t>& decrease_axis,
+                    DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/slice_sig.cc b/paddle/phi/ops/compat/slice_sig.cc
new file mode 100644
index 0000000000000..ba3bafdaa51c7
--- /dev/null
+++ b/paddle/phi/ops/compat/slice_sig.cc
@@ -0,0 +1,183 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature SliceOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  // if input is Tensor Array
+  if (ctx.IsDenseTensorVectorInput("Input")) {
+    return KernelSignature("unregistered", {}, {}, {});
+  }
+
+  if (ctx.HasInput("StartsTensor")) {
+    if (ctx.HasInput("EndsTensor")) {
+      return KernelSignature("slice",
+                             {"Input"},
+                             {"axes",
+                              "StartsTensor",
+                              "EndsTensor",
+                              "infer_flags",
+                              "decrease_axis"},
+                             {"Out"});
+    } else if (ctx.InputSize("EndsTensorList") > 0) {
+      return KernelSignature("slice",
+                             {"Input"},
+                             {"axes",
+                              "StartsTensor",
+                              "EndsTensorList",
+                              "infer_flags",
+                              "decrease_axis"},
+                             {"Out"});
+    } else {
+      return KernelSignature(
+          "slice",
+          {"Input"},
+          {"axes", "StartsTensor", "ends", "infer_flags", "decrease_axis"},
+          {"Out"});
+    }
+  } else if (ctx.InputSize("StartsTensorList") > 0) {
+    if (ctx.HasInput("EndsTensor")) {
+      return KernelSignature("slice",
+                             {"Input"},
+                             {"axes",
+                              "StartsTensorList",
+                              "EndsTensor",
+                              "infer_flags",
+                              "decrease_axis"},
+                             {"Out"});
+    } else if (ctx.InputSize("EndsTensorList") > 0) {
+      return KernelSignature("slice",
+                             {"Input"},
+                             {"axes",
+                              "StartsTensorList",
+                              "EndsTensorList",
+                              "infer_flags",
+                              "decrease_axis"},
+                             {"Out"});
+    } else {
+      return KernelSignature(
+          "slice",
+          {"Input"},
+          {"axes", "StartsTensorList", "ends", "infer_flags", "decrease_axis"},
+          {"Out"});
+    }
+  } else {
+    if (ctx.HasInput("EndsTensor")) {
+      return KernelSignature(
+          "slice",
+          {"Input"},
+          {"axes", "starts", "EndsTensor", "infer_flags", "decrease_axis"},
+          {"Out"});
+    } else if (ctx.InputSize("EndsTensorList") > 0) {
+      return KernelSignature(
+          "slice",
+          {"Input"},
+          {"axes", "starts", "EndsTensorList", "infer_flags", "decrease_axis"},
+          {"Out"});
+    } else {
+      return KernelSignature(
+          "slice",
+          {"Input"},
+          {"axes", "starts", "ends", "infer_flags", "decrease_axis"},
+          {"Out"});
+    }
+  }
+}
+
+KernelSignature SliceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.IsDenseTensorVectorInput("Input")) {
+    return KernelSignature("unregistered", {}, {}, {});
+  }
+
+  if (ctx.HasInput("StartsTensor")) {
+    if (ctx.HasInput("EndsTensor")) {
+      return KernelSignature("slice_grad",
+                             {"Input", GradVarName("Out")},
+                             {"axes",
+                              "StartsTensor",
+                              "EndsTensor",
+                              "infer_flags",
+                              "decrease_axis"},
+                             {GradVarName("Input")});
+    } else if (ctx.InputSize("EndsTensorList") > 0) {
+      return KernelSignature("slice_grad",
+                             {"Input", GradVarName("Out")},
+                             {"axes",
+                              "StartsTensor",
+                              "EndsTensorList",
+                              "infer_flags",
+                              "decrease_axis"},
+                             {GradVarName("Input")});
+    } else {
+      return KernelSignature(
+          "slice_grad",
+          {"Input", GradVarName("Out")},
+          {"axes", "StartsTensor", "ends", "infer_flags", "decrease_axis"},
+          {GradVarName("Input")});
+    }
+  } else if (ctx.InputSize("StartsTensorList") > 0) {
+    if (ctx.HasInput("EndsTensor")) {
+      return KernelSignature("slice_grad",
+                             {"Input", GradVarName("Out")},
+                             {"axes",
+                              "StartsTensorList",
+                              "EndsTensor",
+                              "infer_flags",
+                              "decrease_axis"},
+                             {GradVarName("Input")});
+    } else if (ctx.InputSize("EndsTensorList") > 0) {
+      return KernelSignature("slice_grad",
+                             {"Input", GradVarName("Out")},
+                             {"axes",
+                              "StartsTensorList",
+                              "EndsTensorList",
+                              "infer_flags",
+                              "decrease_axis"},
+                             {GradVarName("Input")});
+    } else {
"infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + } + } else { + if (ctx.HasInput("EndsTensor")) { + return KernelSignature( + "slice_grad", + {"Input", GradVarName("Out")}, + {"axes", "starts", "EndsTensor", "infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + } else if (ctx.InputSize("EndsTensorList") > 0) { + return KernelSignature( + "slice_grad", + {"Input", GradVarName("Out")}, + {"axes", "starts", "EndsTensorList", "infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + } else { + return KernelSignature( + "slice_grad", + {"Input", GradVarName("Out")}, + {"axes", "starts", "ends", "infer_flags", "decrease_axis"}, + {GradVarName("Input")}); + } + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(slice, phi::SliceOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(slice_grad, phi::SliceGradOpArgumentMapping); diff --git a/paddle/pten/kernels/slice_kernel.h b/paddle/pten/kernels/slice_kernel.h new file mode 100644 index 0000000000000..ff27824b9e676 --- /dev/null +++ b/paddle/pten/kernels/slice_kernel.h @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar_array.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void SliceRawKernel(const Context& ctx, + const DenseTensor& input, + const std::vector& axes, + const ScalarArray& starts, + const ScalarArray& ends, + const std::vector& infer_flags, + const std::vector& decrease_axis, + DenseTensor* out); + +} // namespace phi diff --git a/python/paddle/fluid/tests/unittests/test_slice_op.py b/python/paddle/fluid/tests/unittests/test_slice_op.py index 629d61d01b283..71869b96aedf0 100644 --- a/python/paddle/fluid/tests/unittests/test_slice_op.py +++ b/python/paddle/fluid/tests/unittests/test_slice_op.py @@ -796,4 +796,5 @@ def test_input_cuda_pinned_var(self): if __name__ == '__main__': + paddle.enable_static() unittest.main()