From 32f42e941b0efeb397b4fd16bf00b0281c09d014 Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Tue, 30 Aug 2022 14:52:55 +0800 Subject: [PATCH] [OpAttr]Adapt tensor axis for reduce_min/max/mean/sum/prod (#45078) * [OpAttr]Adapt tensor axis for reduce_min/max/mean/sum/prod --- .../operators/reduce_ops/reduce_max_op.cc | 7 +- .../operators/reduce_ops/reduce_mean_op.cc | 7 +- .../operators/reduce_ops/reduce_min_op.cc | 7 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 3 +- .../operators/reduce_ops/reduce_prod_op.cc | 7 +- paddle/phi/api/yaml/legacy_api.yaml | 18 ++-- paddle/phi/api/yaml/legacy_backward.yaml | 28 +++--- paddle/phi/infermeta/unary.cc | 74 ++++++++++++-- paddle/phi/infermeta/unary.h | 23 ++++- .../cpu/graph_send_ue_recv_grad_kernel.cc | 8 +- .../kernels/cpu/graph_send_uv_grad_kernel.cc | 4 +- .../phi/kernels/cpu/identity_loss_kernel.cc | 2 +- .../phi/kernels/cpu/matrix_rank_tol_kernel.cc | 2 +- paddle/phi/kernels/cpu/reduce_max_kernel.cc | 4 +- .../kernels/cpu/reduce_mean_grad_kernel.cc | 12 ++- paddle/phi/kernels/cpu/reduce_mean_kernel.cc | 4 +- paddle/phi/kernels/cpu/reduce_min_kernel.cc | 4 +- paddle/phi/kernels/cpu/reduce_prod_kernel.cc | 4 +- .../phi/kernels/cpu/reduce_sum_grad_kernel.cc | 18 +++- paddle/phi/kernels/cpu/reduce_sum_kernel.cc | 4 +- .../gpu/graph_send_ue_recv_grad_kernel.cu | 8 +- .../kernels/gpu/graph_send_uv_grad_kernel.cu | 4 +- .../phi/kernels/gpu/matrix_rank_tol_kernel.cu | 2 +- .../kernels/gpu/reduce_mean_grad_kernel.cu | 6 +- .../phi/kernels/gpu/reduce_sum_grad_kernel.cu | 4 +- paddle/phi/kernels/impl/einsum_grad_impl.h | 2 +- paddle/phi/kernels/impl/einsum_impl.h | 3 +- paddle/phi/kernels/impl/lstsq_kernel_impl.h | 4 +- .../impl/reduce_max_grad_kernel_impl.h | 4 +- .../impl/reduce_min_grad_kernel_impl.h | 4 +- .../impl/reduce_prod_grad_kernel_impl.h | 5 +- paddle/phi/kernels/kps/reduce_max_kernel.cu | 6 +- paddle/phi/kernels/kps/reduce_mean_kernel.cu | 6 +- paddle/phi/kernels/kps/reduce_min_kernel.cu | 6 +- paddle/phi/kernels/kps/reduce_prod_kernel.cu | 6 +- paddle/phi/kernels/kps/reduce_sum_kernel.cu | 6 +- paddle/phi/kernels/reduce_max_grad_kernel.h | 3 +- paddle/phi/kernels/reduce_max_kernel.cc | 2 +- paddle/phi/kernels/reduce_max_kernel.h | 5 +- paddle/phi/kernels/reduce_mean_grad_kernel.h | 3 +- paddle/phi/kernels/reduce_mean_kernel.cc | 2 +- paddle/phi/kernels/reduce_mean_kernel.h | 6 +- paddle/phi/kernels/reduce_min_grad_kernel.h | 3 +- paddle/phi/kernels/reduce_min_kernel.cc | 2 +- paddle/phi/kernels/reduce_min_kernel.h | 5 +- paddle/phi/kernels/reduce_prod_grad_kernel.h | 3 +- paddle/phi/kernels/reduce_prod_kernel.cc | 2 +- paddle/phi/kernels/reduce_prod_kernel.h | 5 +- paddle/phi/kernels/reduce_sum_grad_kernel.h | 3 +- paddle/phi/kernels/reduce_sum_kernel.cc | 2 +- paddle/phi/kernels/reduce_sum_kernel.h | 7 +- paddle/phi/tests/kernels/test_sum_dev_api.cc | 4 +- .../fluid/tests/unittests/test_max_op.py | 37 +++++-- .../fluid/tests/unittests/test_mean_op.py | 25 +++++ .../fluid/tests/unittests/test_min_op.py | 38 ++++++-- .../fluid/tests/unittests/test_prod_op.py | 25 +++++ .../fluid/tests/unittests/test_sum_op.py | 97 +++++++++++++++++++ python/paddle/tensor/math.py | 78 +++++++++------ python/paddle/tensor/stat.py | 29 ++++-- 59 files changed, 513 insertions(+), 189 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/reduce_max_op.cc b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc index bfc89403d8dd7..21e16a5cd14c6 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_max_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_max_op.cc @@ -25,9 +25,10 @@ class ReduceMaxOpMaker : public ops::ReduceOpMaker { virtual std::string GetOpType() const { return "Reduce reduce_max"; } }; -DECLARE_INFER_SHAPE_FUNCTOR(reduce_max, - ReduceMaxInferShapeFunctor, - PD_INFER_META(phi::ReduceInferMetaBase)); +DECLARE_INFER_SHAPE_FUNCTOR( + reduce_max, + ReduceMaxInferShapeFunctor, + PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase)); REGISTER_OPERATOR( reduce_max, diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc index 9afa493e4ad9e..a5827b1e0a9d1 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cc @@ -97,9 +97,10 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker { virtual std::string GetOpType() const { return "Reduce reduce_mean"; } }; -DECLARE_INFER_SHAPE_FUNCTOR(reduce_mean, - ReduceMeanInferShapeFunctor, - PD_INFER_META(phi::ReduceInferMetaBase)); +DECLARE_INFER_SHAPE_FUNCTOR( + reduce_mean, + ReduceMeanInferShapeFunctor, + PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase)); REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, diff --git a/paddle/fluid/operators/reduce_ops/reduce_min_op.cc b/paddle/fluid/operators/reduce_ops/reduce_min_op.cc index 2dced4fecee12..e9fafce1332d2 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_min_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_min_op.cc @@ -25,9 +25,10 @@ class ReduceMinOpMaker : public ops::ReduceOpMaker { virtual std::string GetOpType() const { return "Reduce reduce_min"; } }; -DECLARE_INFER_SHAPE_FUNCTOR(reduce_min, - ReduceMinInferShapeFunctor, - PD_INFER_META(phi::ReduceInferMetaBase)); +DECLARE_INFER_SHAPE_FUNCTOR( + reduce_min, + ReduceMinInferShapeFunctor, + PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase)); REGISTER_OPERATOR( reduce_min, diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 9e53a6b56de5c..df7804dc7a68d 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -698,7 +698,8 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { "Must be in the range [-rank(input), rank(input)). " "If `dim[i] < 0`, the dims[i] to reduce is `rank + dims[i]`. " "Note that reducing on the first dim will make the LoD info lost.") - .SetDefault({0}); + .SetDefault({0}) + .SupportTensor(); AddAttr("keep_dim", "(bool, default false) " "If true, retain the reduced dimension with length 1.") diff --git a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc index 578954663c7f5..0b78e80746412 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_prod_op.cc @@ -35,9 +35,10 @@ class ReduceProdOpMaker : public ops::ReduceOpMaker { virtual std::string GetOpType() const { return "Reduce reduce_prod"; } }; -DECLARE_INFER_SHAPE_FUNCTOR(reduce_prod, - ReduceProdInferShapeFunctor, - PD_INFER_META(phi::ReduceInferMetaBase)); +DECLARE_INFER_SHAPE_FUNCTOR( + reduce_prod, + ReduceProdInferShapeFunctor, + PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase)); REGISTER_OPERATOR( reduce_prod, diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index c74e0b0f73a02..4dff5c4ac8c8f 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1668,10 +1668,10 @@ func : matrix_rank_tol - api : max - args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) + args : (Tensor x, IntArray dims={}, bool keep_dim=false) output : Tensor(out) infer_meta : - func : ReduceInferMeta + func : ReduceIntArrayAxisInferMeta kernel : func : max backward : max_grad @@ -1713,10 +1713,10 @@ backward : maxout_grad - api : mean - args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) + args : (Tensor x, IntArray dims={}, bool keep_dim=false) output : Tensor(out) infer_meta : - func : ReduceInferMeta + func : ReduceIntArrayAxisInferMeta kernel : func : mean backward : mean_grad @@ -1762,10 +1762,10 @@ backward : meshgrid_grad - api : min - args : (Tensor x, int64_t[] dims={}, bool keep_dim=false) + args : (Tensor x, IntArray dims={}, bool keep_dim=false) output : Tensor(out) infer_meta : - func : ReduceInferMeta + func : ReduceIntArrayAxisInferMeta kernel : func : min backward : min_grad @@ -2091,10 +2091,10 @@ backward : reciprocal_grad - api : reduce_prod - args : (Tensor x, int64_t[] dims, bool keep_dim, bool reduce_all) + args : (Tensor x, IntArray dims, bool keep_dim, bool reduce_all) output : Tensor infer_meta : - func : ReduceInferMetaBase + func : ReduceIntArrayAxisInferMetaBase kernel : func : prod_raw backward : reduce_prod_grad @@ -2555,7 +2555,7 @@ backward : subtract_grad - api : sum - args : (Tensor x, int64_t[] dims={}, DataType out_dtype=DataType::UNDEFINED, bool keep_dim=false) + args : (Tensor x, IntArray dims={}, DataType out_dtype=DataType::UNDEFINED, bool keep_dim=false) output : Tensor(out) infer_meta : func : SumInferMeta diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 0a3326574ee67..a19b54be72425 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1451,8 +1451,8 @@ func : matrix_power_grad - backward_api : max_grad - forward: max (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out) - args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false) + forward: max (Tensor x, IntArray dims={}, bool keep_dim=false) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, IntArray dims={}, bool keep_dim=false, bool reduce_all=false) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta @@ -1509,14 +1509,14 @@ func : mean_all_grad - backward_api : mean_double_grad - forward: mean_grad (Tensor x, Tensor grad_out, int64_t[] dims={}, bool keep_dim=false, bool reduce_all = false) -> Tensor(grad_x) - args : (Tensor grad_x_grad, int64_t[] dims={}, bool keep_dim=false) + forward: mean_grad (Tensor x, Tensor grad_out, IntArray dims={}, bool keep_dim=false, bool reduce_all = false) -> Tensor(grad_x) + args : (Tensor grad_x_grad, IntArray dims={}, bool keep_dim=false) output : Tensor(grad_out_grad) invoke : mean(grad_x_grad, dims, keep_dim) - backward_api : mean_grad - forward: mean (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out) - args : (Tensor x, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false) + forward: mean (Tensor x, IntArray dims={}, bool keep_dim=false) -> Tensor(out) + args : (Tensor x, Tensor out_grad, IntArray dims={}, bool keep_dim=false, bool reduce_all=false) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta @@ -1536,8 +1536,8 @@ func : meshgrid_grad - backward_api : min_grad - forward: min (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out) - args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false) + forward: min (Tensor x, IntArray dims={}, bool keep_dim=false) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, IntArray dims={}, bool keep_dim=false, bool reduce_all=false) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta @@ -1849,8 +1849,8 @@ inplace : (out_grad -> x_grad) - backward_api : reduce_prod_grad - forward : reduce_prod (Tensor x, int64_t[] dims, bool keep_dim, bool reduce_all) -> Tensor(out) - args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims, bool keep_dim, bool reduce_all) + forward : reduce_prod (Tensor x, IntArray dims, bool keep_dim, bool reduce_all) -> Tensor(out) + args : (Tensor x, Tensor out, Tensor out_grad, IntArray dims, bool keep_dim, bool reduce_all) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta @@ -2386,14 +2386,14 @@ inplace : (out_grad -> x_grad) - backward_api : sum_double_grad - forward : sum_grad (Tensor x, Tensor grad_out, int64_t[] dims, bool keep_dim, bool reduce_all=false) -> Tensor(grad_x) - args : (Tensor grad_x_grad, int64_t[] dims={}, bool keep_dim=false) + forward : sum_grad (Tensor x, Tensor grad_out, IntArray dims, bool keep_dim, bool reduce_all=false) -> Tensor(grad_x) + args : (Tensor grad_x_grad, IntArray dims={}, bool keep_dim=false) output : Tensor(grad_out_grad) invoke : sum(grad_x_grad, dims, grad_x_grad.dtype(), keep_dim) - backward_api : sum_grad - forward : sum (Tensor x, int64_t[] dims={}, DataType out_dtype=DataType::UNDEFINED, bool keep_dim=false) -> Tensor(out) - args : (Tensor x, Tensor out_grad, int64_t[] dims, bool keep_dim, bool reduce_all=false) + forward : sum (Tensor x, IntArray dims={}, DataType out_dtype=DataType::UNDEFINED, bool keep_dim=false) -> Tensor(out) + args : (Tensor x, Tensor out_grad, IntArray dims, bool keep_dim, bool reduce_all=false) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 82ebc89c666b5..56e524b1bbb91 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2656,6 +2656,33 @@ DDim ReduceInferDim(const MetaTensor& x, return out_dim; } +DDim ReduceInferDimForIntArrayAxis(const MetaTensor& x, + const IntArray& axis, + bool keep_dim, + bool reduce_all) { + std::vector vec_axis = axis.GetData(); + std::vector vec_dim; + if (reduce_all) { + if (keep_dim) { + vec_dim = std::vector(x.dims().size(), 1); + } else { + vec_dim = {1}; + } + } else { + if (keep_dim) { + vec_dim = std::vector(x.dims().size(), -1); + } else { + auto x_rank = static_cast(x.dims().size()); + if (vec_axis.size() >= x_rank) { + vec_dim = {-1}; + } else { + vec_dim = std::vector(x.dims().size() - vec_axis.size(), -1); + } + } + } + return phi::make_ddim(vec_dim); +} + void ReduceInferMeta(const MetaTensor& x, const std::vector& axis, bool keep_dim, @@ -2678,6 +2705,34 @@ void ReduceInferMetaBase(const MetaTensor& x, out->set_layout(x.layout()); } +void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x, + const IntArray& axis, + bool keep_dim, + bool reduce_all, + MetaTensor* out, + MetaConfig config) { + if (config.is_runtime || !axis.FromTensor()) { + ReduceInferMetaBase(x, axis.GetData(), keep_dim, reduce_all, out); + } else { + DDim out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all); + out->set_dims(out_dim); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); + } +} + +void ReduceIntArrayAxisInferMeta(const MetaTensor& x, + const IntArray& axis, + bool keep_dim, + MetaTensor* out, + MetaConfig config) { + bool reduce_all = false; + if (axis.size() == 0) { + reduce_all = true; + } + ReduceIntArrayAxisInferMetaBase(x, axis, keep_dim, reduce_all, out, config); +} + void RepeatInterleaveInferMeta(const MetaTensor& x, int repeats, int dim, @@ -3354,24 +3409,31 @@ void StridedSliceInferMeta(const MetaTensor& x, api.yaml */ void SumInferMeta(const MetaTensor& x, - const std::vector& axis, + const IntArray& axis, DataType dtype, bool keep_dim, - MetaTensor* out) { + MetaTensor* out, + MetaConfig config) { bool reduce_all = false; if (axis.size() == 0) { reduce_all = true; } - SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out); + SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out, config); } void SumRawInferMeta(const MetaTensor& x, - const std::vector& axis, + const IntArray& axis, bool keep_dim, bool reduce_all, DataType dtype, - MetaTensor* out) { - DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all); + MetaTensor* out, + MetaConfig config) { + DDim out_dim; + if (config.is_runtime || !axis.FromTensor()) { + out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all); + } else { + out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all); + } DataType out_dtype; if (dtype != DataType::UNDEFINED) { diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 1d1995db43f64..9a67066cab2c5 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -372,6 +372,19 @@ void ReduceInferMetaBase(const MetaTensor& x, bool reduce_all, MetaTensor* out); +void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x, + const IntArray& axis, + bool keep_dim, + bool reduce_all, + MetaTensor* out, + MetaConfig config = MetaConfig()); + +void ReduceIntArrayAxisInferMeta(const MetaTensor& x, + const IntArray& axis, + bool keep_dim, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void RepeatInterleaveInferMeta(const MetaTensor& x, int repeats, int dim, @@ -477,17 +490,19 @@ void StridedSliceInferMeta(const MetaTensor& x, MetaConfig config = MetaConfig()); void SumInferMeta(const MetaTensor& x, - const std::vector& axis, + const IntArray& axis, DataType dtype, bool keep_dim, - MetaTensor* out); + MetaTensor* out, + MetaConfig config = MetaConfig()); void SumRawInferMeta(const MetaTensor& x, - const std::vector& axis, + const IntArray& axis, bool keep_dim, bool reduce_all, DataType dtype, - MetaTensor* out); + MetaTensor* out, + MetaConfig config = MetaConfig()); void SvdInferMeta(const MetaTensor& x, bool full_matrices, diff --git a/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc index c7b1e3e51853b..912426a778d0c 100644 --- a/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc @@ -73,7 +73,7 @@ void CalculateXGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); @@ -131,7 +131,7 @@ void CalculateXGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); @@ -166,7 +166,7 @@ void CalculateXGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); @@ -220,7 +220,7 @@ void CalculateXGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); diff --git a/paddle/phi/kernels/cpu/graph_send_uv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_uv_grad_kernel.cc index 4e28acdad3db4..23e5172c3afa7 100644 --- a/paddle/phi/kernels/cpu/graph_send_uv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/graph_send_uv_grad_kernel.cc @@ -86,7 +86,7 @@ void CalculateGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); @@ -148,7 +148,7 @@ void CalculateGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); memcpy(x_grad, x_grad_out.data(), x_grad_out.numel() * sizeof(T)); diff --git a/paddle/phi/kernels/cpu/identity_loss_kernel.cc b/paddle/phi/kernels/cpu/identity_loss_kernel.cc index 941174eb5b0bd..5edf6d9e16024 100644 --- a/paddle/phi/kernels/cpu/identity_loss_kernel.cc +++ b/paddle/phi/kernels/cpu/identity_loss_kernel.cc @@ -31,7 +31,7 @@ void IdentityLossKernel(const Context& dev_ctx, case 0: // sum phi::SumRawKernel( - dev_ctx, x, std::vector{0}, false, true, out->dtype(), out); + dev_ctx, x, phi::IntArray({0}), false, true, out->dtype(), out); break; case 1: // mean diff --git a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc index 4f941099c9d05..491e9c5d210cb 100644 --- a/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc +++ b/paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc @@ -118,7 +118,7 @@ void MatrixRankTolKernel(const Context& dev_ctx, dev_ctx.template Alloc(&max_eigenvalue_tensor); phi::MaxKernel(dev_ctx, eigenvalue_tensor, - std::vector{-1}, + phi::IntArray({-1}), false, &max_eigenvalue_tensor); diff --git a/paddle/phi/kernels/cpu/reduce_max_kernel.cc b/paddle/phi/kernels/cpu/reduce_max_kernel.cc index f9ea0aa0faf06..b15a555a2cf4d 100644 --- a/paddle/phi/kernels/cpu/reduce_max_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_max_kernel.cc @@ -24,13 +24,13 @@ namespace phi { template void MaxRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out) { auto out_dtype = x.dtype(); phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc index 77176d5d7469e..3ab8a40a85e55 100644 --- a/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_mean_grad_kernel.cc @@ -24,12 +24,18 @@ template void ReduceMeanGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { - ReduceGradKernel( - dev_ctx, x, paddle::none, out_grad, dims, keep_dim, reduce_all, x_grad); + ReduceGradKernel(dev_ctx, + x, + paddle::none, + out_grad, + dims.GetData(), + keep_dim, + reduce_all, + x_grad); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/reduce_mean_kernel.cc b/paddle/phi/kernels/cpu/reduce_mean_kernel.cc index 8fa687632f653..7164ec8b2bf99 100644 --- a/paddle/phi/kernels/cpu/reduce_mean_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_mean_kernel.cc @@ -24,13 +24,13 @@ namespace phi { template void MeanRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out) { auto out_dtype = x.dtype(); phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/reduce_min_kernel.cc b/paddle/phi/kernels/cpu/reduce_min_kernel.cc index 0a241c81dbe69..a11de5ea81ab6 100644 --- a/paddle/phi/kernels/cpu/reduce_min_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_min_kernel.cc @@ -24,13 +24,13 @@ namespace phi { template void MinRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out) { auto out_dtype = x.dtype(); phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/reduce_prod_kernel.cc b/paddle/phi/kernels/cpu/reduce_prod_kernel.cc index d31a6e5626289..36766d27ed434 100644 --- a/paddle/phi/kernels/cpu/reduce_prod_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_prod_kernel.cc @@ -24,13 +24,13 @@ namespace phi { template void ProdRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out) { auto out_dtype = x.dtype(); phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc index abc18b1c578a8..87e3df717b244 100644 --- a/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_sum_grad_kernel.cc @@ -73,7 +73,7 @@ template void ReduceSumGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { @@ -84,17 +84,25 @@ void ReduceSumGradKernel(const Context& dev_ctx, DenseTensor x_grad_tmp = phi::Empty(dev_ctx, std::move(x_grad_meta)); - ComputeFromInput(dev_ctx, x, out_grad, dims, &x_grad_tmp); + ComputeFromInput( + dev_ctx, x, out_grad, dims.GetData(), &x_grad_tmp); phi::CastKernel(dev_ctx, x_grad_tmp, x.dtype(), x_grad); } else { - ComputeFromInput(dev_ctx, x, out_grad, dims, x_grad); + ComputeFromInput( + dev_ctx, x, out_grad, dims.GetData(), x_grad); } } - ReduceGradKernel( - dev_ctx, x, paddle::none, out_grad, dims, keep_dim, reduce_all, x_grad); + ReduceGradKernel(dev_ctx, + x, + paddle::none, + out_grad, + dims.GetData(), + keep_dim, + reduce_all, + x_grad); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/reduce_sum_kernel.cc b/paddle/phi/kernels/cpu/reduce_sum_kernel.cc index 0b4c4b9f4705a..95e5a1a214bb9 100644 --- a/paddle/phi/kernels/cpu/reduce_sum_kernel.cc +++ b/paddle/phi/kernels/cpu/reduce_sum_kernel.cc @@ -24,7 +24,7 @@ namespace phi { template void SumRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DataType out_dtype, @@ -33,7 +33,7 @@ void SumRawKernel(const Context& dev_ctx, out_dtype = out->dtype(); } phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu index a1d522cc3d4d1..41667be1a9545 100644 --- a/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu @@ -158,7 +158,7 @@ void CalculateXGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); #ifdef PADDLE_WITH_HIP @@ -235,7 +235,7 @@ void CalculateXGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); #ifdef PADDLE_WITH_HIP @@ -281,7 +281,7 @@ void CalculateXGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); #ifdef PADDLE_WITH_HIP @@ -349,7 +349,7 @@ void CalculateXGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); // TODO(daisiming): Whether use x_grad instead. diff --git a/paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu index d845e9cc4372a..1671fa7e17cd0 100644 --- a/paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu @@ -108,7 +108,7 @@ void CalculateGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); #ifdef PADDLE_WITH_HIP @@ -190,7 +190,7 @@ void CalculateGrad(const Context& ctx, DenseTensor x_grad_out = phi::Sum( ctx, x_grad_v2, - reduce_idx, + phi::IntArray(reduce_idx), paddle::experimental::CppTypeToDataType::Type(), true); #ifdef PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu index 5661d61c4e8c5..050c6d2faf535 100644 --- a/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu +++ b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu @@ -378,7 +378,7 @@ void MatrixRankTolKernel(const Context& dev_ctx, phi::MaxKernel(dev_ctx, eigenvalue_tensor, - std::vector{-1}, + phi::IntArray({-1}), false, &max_eigenvalue_tensor); diff --git a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu index 50564a339ddc0..7da2502a5eea7 100644 --- a/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_mean_grad_kernel.cu @@ -25,13 +25,13 @@ template void ReduceMeanGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { int dim_size = x.dims().size(); std::vector reduce_dims = - funcs::details::GetReduceDim(dims, dim_size, reduce_all); + funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all); int reduce_num = 1; for (auto i : reduce_dims) { reduce_num *= (x.dims())[i]; @@ -41,7 +41,7 @@ void ReduceMeanGradKernel(const Context& dev_ctx, dev_ctx, x, out_grad, - dims, + dims.GetData(), keep_dim, reduce_all, x_grad, diff --git a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu index c0955cd7424ae..2230b4b8525b3 100644 --- a/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/reduce_sum_grad_kernel.cu @@ -25,7 +25,7 @@ template void ReduceSumGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { @@ -41,7 +41,7 @@ void ReduceSumGradKernel(const Context& dev_ctx, reduce_all = true; } std::vector reduce_dims = - funcs::details::GetReduceDim(dims, dim_size, reduce_all); + funcs::details::GetReduceDim(dims.GetData(), dim_size, reduce_all); auto update_dims = vectorize(d_x->dims()); int reduce_num = 1; diff --git a/paddle/phi/kernels/impl/einsum_grad_impl.h b/paddle/phi/kernels/impl/einsum_grad_impl.h index a04185a0c53ed..992b7572c1be5 100644 --- a/paddle/phi/kernels/impl/einsum_grad_impl.h +++ b/paddle/phi/kernels/impl/einsum_grad_impl.h @@ -80,7 +80,7 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, if (to_reduce.size() != 0) { ret = Sum(dev_ctx, after_tile, - to_reduce, + phi::IntArray(to_reduce), after_tile.dtype(), false); // not keep dim. } else { diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h index b5bc826881af8..80529c8b669aa 100644 --- a/paddle/phi/kernels/impl/einsum_impl.h +++ b/paddle/phi/kernels/impl/einsum_impl.h @@ -383,7 +383,8 @@ DenseTensor PerformReduction(const Context& dev_ctx, VLOG(5) << "call PerformReduction: with axis: " << paddle::string::join_strings(indices, ","); if (indices.size() == 0) return tensor; - return Sum(dev_ctx, tensor, indices, tensor.dtype(), true); + return Sum( + dev_ctx, tensor, phi::IntArray(indices), tensor.dtype(), true); } inline bool is_no_need_transpose(const std::vector& axis) { diff --git a/paddle/phi/kernels/impl/lstsq_kernel_impl.h b/paddle/phi/kernels/impl/lstsq_kernel_impl.h index 73ba954614a22..1644d0852a87d 100644 --- a/paddle/phi/kernels/impl/lstsq_kernel_impl.h +++ b/paddle/phi/kernels/impl/lstsq_kernel_impl.h @@ -69,8 +69,8 @@ inline void GetResidualsTensor(const DeviceContext& dev_ctx, dev_ctx.template Alloc(pow_tensor); phi::PowKernel(dev_ctx, sub_tensor, Scalar(2), pow_tensor); - auto sum_tensor = - phi::Sum(dev_ctx, *pow_tensor, {-2}, pow_tensor->dtype(), false); + auto sum_tensor = phi::Sum( + dev_ctx, *pow_tensor, phi::IntArray({-2}), pow_tensor->dtype(), false); phi::Copy( dev_ctx, sum_tensor, dev_ctx.GetPlace(), true, residuals); } else { diff --git a/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h index 83dd4a2b576bb..33730a3717781 100644 --- a/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/reduce_max_grad_kernel_impl.h @@ -25,12 +25,12 @@ void ReduceMaxGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { ReduceGradKernel( - dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad); + dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad); } } // namespace phi diff --git a/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h index 592b5309cd970..93afa07ff01af 100644 --- a/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/reduce_min_grad_kernel_impl.h @@ -25,12 +25,12 @@ void ReduceMinGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { ReduceGradKernel( - dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad); + dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad); } } // namespace phi diff --git a/paddle/phi/kernels/impl/reduce_prod_grad_kernel_impl.h b/paddle/phi/kernels/impl/reduce_prod_grad_kernel_impl.h index 69775281a259c..a6f92543cc9c6 100644 --- a/paddle/phi/kernels/impl/reduce_prod_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/reduce_prod_grad_kernel_impl.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/kernels/funcs/reduce_functor.h" #include "paddle/phi/kernels/impl/reduce_grad.h" #include "paddle/phi/kernels/reduce_prod_grad_kernel.h" @@ -25,12 +26,12 @@ void ReduceProdGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad) { ReduceGradKernel( - dev_ctx, x, out, out_grad, dims, keep_dim, reduce_all, x_grad); + dev_ctx, x, out, out_grad, dims.GetData(), keep_dim, reduce_all, x_grad); } } // namespace phi diff --git a/paddle/phi/kernels/kps/reduce_max_kernel.cu b/paddle/phi/kernels/kps/reduce_max_kernel.cu index 52644849ad8bf..fb47b64f6ecec 100644 --- a/paddle/phi/kernels/kps/reduce_max_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_max_kernel.cu @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/reduce_max_kernel.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/gpu/reduce.h" -#include "paddle/phi/kernels/reduce_max_kernel.h" namespace phi { template void MaxRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out) { auto out_dtype = x.dtype(); phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); } } // namespace phi diff --git a/paddle/phi/kernels/kps/reduce_mean_kernel.cu b/paddle/phi/kernels/kps/reduce_mean_kernel.cu index c4ecd4380c306..7f7946e030063 100644 --- a/paddle/phi/kernels/kps/reduce_mean_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_mean_kernel.cu @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/reduce_mean_kernel.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/gpu/reduce.h" -#include "paddle/phi/kernels/reduce_mean_kernel.h" namespace phi { template void MeanRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out) { auto out_dtype = x.dtype(); phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out, true); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out, true); } } // namespace phi diff --git a/paddle/phi/kernels/kps/reduce_min_kernel.cu b/paddle/phi/kernels/kps/reduce_min_kernel.cu index 6fea48b588abb..9c3e61d3c0bc5 100644 --- a/paddle/phi/kernels/kps/reduce_min_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_min_kernel.cu @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/reduce_min_kernel.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/gpu/reduce.h" -#include "paddle/phi/kernels/reduce_min_kernel.h" namespace phi { template void MinRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out) { auto out_dtype = x.dtype(); phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); } } // namespace phi diff --git a/paddle/phi/kernels/kps/reduce_prod_kernel.cu b/paddle/phi/kernels/kps/reduce_prod_kernel.cu index 13d8e29b60b12..f5b52937e36fe 100644 --- a/paddle/phi/kernels/kps/reduce_prod_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_prod_kernel.cu @@ -12,22 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/reduce_prod_kernel.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/gpu/reduce.h" -#include "paddle/phi/kernels/reduce_prod_kernel.h" namespace phi { template void ProdRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out) { auto out_dtype = x.dtype(); phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); } } // namespace phi diff --git a/paddle/phi/kernels/kps/reduce_sum_kernel.cu b/paddle/phi/kernels/kps/reduce_sum_kernel.cu index f219abd3348a6..f3d3246854f33 100644 --- a/paddle/phi/kernels/kps/reduce_sum_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_sum_kernel.cu @@ -12,16 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/phi/kernels/reduce_sum_kernel.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/gpu/reduce.h" -#include "paddle/phi/kernels/reduce_sum_kernel.h" namespace phi { template void SumRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DataType out_dtype, @@ -30,7 +30,7 @@ void SumRawKernel(const Context& dev_ctx, out_dtype = out->dtype(); } phi::Reduce( - dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out); + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); } } // namespace phi diff --git a/paddle/phi/kernels/reduce_max_grad_kernel.h b/paddle/phi/kernels/reduce_max_grad_kernel.h index ef3d9f36d28de..d1522667935f6 100644 --- a/paddle/phi/kernels/reduce_max_grad_kernel.h +++ b/paddle/phi/kernels/reduce_max_grad_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -23,7 +24,7 @@ void ReduceMaxGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad); diff --git a/paddle/phi/kernels/reduce_max_kernel.cc b/paddle/phi/kernels/reduce_max_kernel.cc index 7bdf9ba2bbcc6..72dd515fc4321 100644 --- a/paddle/phi/kernels/reduce_max_kernel.cc +++ b/paddle/phi/kernels/reduce_max_kernel.cc @@ -22,7 +22,7 @@ namespace phi { template void MaxKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, DenseTensor* out) { bool reduce_all = false; diff --git a/paddle/phi/kernels/reduce_max_kernel.h b/paddle/phi/kernels/reduce_max_kernel.h index f224f494a7229..2af22a2ddde3d 100644 --- a/paddle/phi/kernels/reduce_max_kernel.h +++ b/paddle/phi/kernels/reduce_max_kernel.h @@ -14,13 +14,14 @@ #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { template void MaxRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out); @@ -28,7 +29,7 @@ void MaxRawKernel(const Context& dev_ctx, template void MaxKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, DenseTensor* out); diff --git a/paddle/phi/kernels/reduce_mean_grad_kernel.h b/paddle/phi/kernels/reduce_mean_grad_kernel.h index ccda3160aa9e5..572e5a0f6fb0c 100644 --- a/paddle/phi/kernels/reduce_mean_grad_kernel.h +++ b/paddle/phi/kernels/reduce_mean_grad_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -22,7 +23,7 @@ template void ReduceMeanGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad); diff --git a/paddle/phi/kernels/reduce_mean_kernel.cc b/paddle/phi/kernels/reduce_mean_kernel.cc index 599b7eca32110..4bb77ac974792 100644 --- a/paddle/phi/kernels/reduce_mean_kernel.cc +++ b/paddle/phi/kernels/reduce_mean_kernel.cc @@ -22,7 +22,7 @@ namespace phi { template void MeanKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, DenseTensor* out) { bool reduce_all = false; diff --git a/paddle/phi/kernels/reduce_mean_kernel.h b/paddle/phi/kernels/reduce_mean_kernel.h index 2ac4bd8a46e64..13c387a86e228 100644 --- a/paddle/phi/kernels/reduce_mean_kernel.h +++ b/paddle/phi/kernels/reduce_mean_kernel.h @@ -22,7 +22,7 @@ namespace phi { template void MeanRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out); @@ -30,14 +30,14 @@ void MeanRawKernel(const Context& dev_ctx, template void MeanKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, DenseTensor* out); template DenseTensor Mean(const Context& dev_ctx, const DenseTensor& x, - const std::vector& axis, + const IntArray& axis, bool keep_dim) { DenseTensor dense_out; MetaTensor meta_out(&dense_out); diff --git a/paddle/phi/kernels/reduce_min_grad_kernel.h b/paddle/phi/kernels/reduce_min_grad_kernel.h index 3c6ea3a3564cf..c737761439089 100644 --- a/paddle/phi/kernels/reduce_min_grad_kernel.h +++ b/paddle/phi/kernels/reduce_min_grad_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -23,7 +24,7 @@ void ReduceMinGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad); diff --git a/paddle/phi/kernels/reduce_min_kernel.cc b/paddle/phi/kernels/reduce_min_kernel.cc index 69725759e4e82..11f11b772ef6f 100644 --- a/paddle/phi/kernels/reduce_min_kernel.cc +++ b/paddle/phi/kernels/reduce_min_kernel.cc @@ -22,7 +22,7 @@ namespace phi { template void MinKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, DenseTensor* out) { bool reduce_all = false; diff --git a/paddle/phi/kernels/reduce_min_kernel.h b/paddle/phi/kernels/reduce_min_kernel.h index bbf3f2ab81826..e6f133cc9ca00 100644 --- a/paddle/phi/kernels/reduce_min_kernel.h +++ b/paddle/phi/kernels/reduce_min_kernel.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -21,7 +22,7 @@ namespace phi { template void MinRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out); @@ -29,7 +30,7 @@ void MinRawKernel(const Context& dev_ctx, template void MinKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, DenseTensor* out); diff --git a/paddle/phi/kernels/reduce_prod_grad_kernel.h b/paddle/phi/kernels/reduce_prod_grad_kernel.h index fbf9f19a1bb82..fb773f167f90b 100644 --- a/paddle/phi/kernels/reduce_prod_grad_kernel.h +++ b/paddle/phi/kernels/reduce_prod_grad_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -23,7 +24,7 @@ void ReduceProdGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad); diff --git a/paddle/phi/kernels/reduce_prod_kernel.cc b/paddle/phi/kernels/reduce_prod_kernel.cc index 3bb1c7552b11f..37f1f7bb8172e 100644 --- a/paddle/phi/kernels/reduce_prod_kernel.cc +++ b/paddle/phi/kernels/reduce_prod_kernel.cc @@ -22,7 +22,7 @@ namespace phi { template void ProdKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, DenseTensor* out) { bool reduce_all = false; diff --git a/paddle/phi/kernels/reduce_prod_kernel.h b/paddle/phi/kernels/reduce_prod_kernel.h index be46a554b57e1..91de087ccbc73 100644 --- a/paddle/phi/kernels/reduce_prod_kernel.h +++ b/paddle/phi/kernels/reduce_prod_kernel.h @@ -14,13 +14,14 @@ #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { template void ProdRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* out); @@ -28,7 +29,7 @@ void ProdRawKernel(const Context& dev_ctx, template void ProdKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, DenseTensor* out); diff --git a/paddle/phi/kernels/reduce_sum_grad_kernel.h b/paddle/phi/kernels/reduce_sum_grad_kernel.h index b8b6618d43ec9..8cea1f4b34594 100644 --- a/paddle/phi/kernels/reduce_sum_grad_kernel.h +++ b/paddle/phi/kernels/reduce_sum_grad_kernel.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" namespace phi { @@ -22,7 +23,7 @@ template void ReduceSumGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& out_grad, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DenseTensor* x_grad); diff --git a/paddle/phi/kernels/reduce_sum_kernel.cc b/paddle/phi/kernels/reduce_sum_kernel.cc index c9622768c45d9..5fed4dbc44d99 100644 --- a/paddle/phi/kernels/reduce_sum_kernel.cc +++ b/paddle/phi/kernels/reduce_sum_kernel.cc @@ -22,7 +22,7 @@ namespace phi { template void SumKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, DataType out_dtype, bool keep_dim, DenseTensor* out) { diff --git a/paddle/phi/kernels/reduce_sum_kernel.h b/paddle/phi/kernels/reduce_sum_kernel.h index c969cea296db1..3bcf025d96bc4 100644 --- a/paddle/phi/kernels/reduce_sum_kernel.h +++ b/paddle/phi/kernels/reduce_sum_kernel.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/infermeta/unary.h" @@ -21,7 +22,7 @@ namespace phi { template void SumRawKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, bool keep_dim, bool reduce_all, DataType out_dtype, @@ -30,7 +31,7 @@ void SumRawKernel(const Context& dev_ctx, template void SumKernel(const Context& dev_ctx, const DenseTensor& x, - const std::vector& dims, + const IntArray& dims, DataType out_dtype, bool keep_dim, DenseTensor* out); @@ -38,7 +39,7 @@ void SumKernel(const Context& dev_ctx, template DenseTensor Sum(const Context& dev_ctx, const DenseTensor& x, - const std::vector& axis, + const IntArray& axis, DataType dtype, bool keep_dim) { DenseTensor dense_out; diff --git a/paddle/phi/tests/kernels/test_sum_dev_api.cc b/paddle/phi/tests/kernels/test_sum_dev_api.cc index 20e934eb69297..b870861eaad2d 100644 --- a/paddle/phi/tests/kernels/test_sum_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sum_dev_api.cc @@ -51,8 +51,8 @@ TEST(DEV_API, sum) { .get()); // 2. test API - auto out = - phi::Sum(dev_ctx, dense_x, axis, phi::DataType::FLOAT32, false); + auto out = phi::Sum( + dev_ctx, dense_x, phi::IntArray(axis), phi::DataType::FLOAT32, false); // 3. check result ASSERT_EQ(out.dims().size(), 1); diff --git a/python/paddle/fluid/tests/unittests/test_max_op.py b/python/paddle/fluid/tests/unittests/test_max_op.py index dc11d78699e73..235eef5d90983 100644 --- a/python/paddle/fluid/tests/unittests/test_max_op.py +++ b/python/paddle/fluid/tests/unittests/test_max_op.py @@ -14,12 +14,16 @@ from __future__ import print_function +import os import unittest +import tempfile import numpy as np from op_test import OpTest, skip_check_grad_ci, check_out_dtype import paddle from paddle.fluid.framework import _test_eager_guard import paddle.fluid.core as core +import paddle.inference as paddle_infer +from test_sum_op import TestReduceOPTensorAxisBase class ApiMaxTest(unittest.TestCase): @@ -70,15 +74,6 @@ def test_input_type(): self.assertRaises(TypeError, test_input_type) - def test_axis_type(): - with paddle.static.program_guard(paddle.static.Program(), - paddle.static.Program()): - data = paddle.static.data("data", shape=[10, 10], dtype="int64") - axis = paddle.static.data("axis", shape=[10, 10], dtype="int64") - result_min = paddle.min(data, axis) - - self.assertRaises(TypeError, test_axis_type) - def test_imperative_api(self): paddle.disable_static() np_x = np.array([10, 10]).astype('float64') @@ -124,5 +119,29 @@ def test_max(self): expect_dtypes=['float32', 'float64', 'int32', 'int64']) +class TestMaxWithTensorAxis1(TestReduceOPTensorAxisBase): + + def init_data(self): + self.pd_api = paddle.max + self.np_api = np.max + self.x = paddle.randn([10, 5, 9, 9], dtype='float64') + self.np_axis = np.array([1, 2], dtype='int64') + self.tensor_axis = paddle.to_tensor([1, 2], dtype='int64') + + +class TestMaxWithTensorAxis2(TestReduceOPTensorAxisBase): + + def init_data(self): + self.pd_api = paddle.max + self.np_api = np.max + self.x = paddle.randn([10, 10, 9, 9], dtype='float64') + self.np_axis = np.array([0, 1, 2], dtype='int64') + self.tensor_axis = [ + 0, + paddle.to_tensor([1], 'int64'), + paddle.to_tensor([2], 'int64') + ] + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_mean_op.py b/python/paddle/fluid/tests/unittests/test_mean_op.py index 965d70a59503f..a4ed6dc3b575c 100644 --- a/python/paddle/fluid/tests/unittests/test_mean_op.py +++ b/python/paddle/fluid/tests/unittests/test_mean_op.py @@ -22,6 +22,7 @@ import paddle.fluid as fluid from paddle.fluid import Program, program_guard from paddle.fluid.framework import _test_eager_guard +from test_sum_op import TestReduceOPTensorAxisBase np.random.seed(10) @@ -408,6 +409,30 @@ def test_errors(self): self.assertRaises(TypeError, paddle.mean, x) +class TestMeanWithTensorAxis1(TestReduceOPTensorAxisBase): + + def init_data(self): + self.pd_api = paddle.mean + self.np_api = np.mean + self.x = paddle.randn([10, 5, 9, 9], dtype='float64') + self.np_axis = np.array([1, 2], dtype='int64') + self.tensor_axis = paddle.to_tensor([1, 2], dtype='int64') + + +class TestMeanWithTensorAxis2(TestReduceOPTensorAxisBase): + + def init_data(self): + self.pd_api = paddle.mean + self.np_api = np.mean + self.x = paddle.randn([10, 10, 9, 9], dtype='float64') + self.np_axis = np.array([0, 1, 2], dtype='int64') + self.tensor_axis = [ + 0, + paddle.to_tensor([1], 'int64'), + paddle.to_tensor([2], 'int64') + ] + + if __name__ == "__main__": paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_min_op.py b/python/paddle/fluid/tests/unittests/test_min_op.py index 6e5f9d1321593..2daa813997154 100644 --- a/python/paddle/fluid/tests/unittests/test_min_op.py +++ b/python/paddle/fluid/tests/unittests/test_min_op.py @@ -14,12 +14,16 @@ from __future__ import print_function +import os import unittest +import tempfile import numpy as np from op_test import OpTest, skip_check_grad_ci, check_out_dtype import paddle import paddle.fluid.core as core from paddle.fluid.framework import _test_eager_guard +import paddle.inference as paddle_infer +from test_sum_op import TestReduceOPTensorAxisBase class ApiMinTest(unittest.TestCase): @@ -70,15 +74,6 @@ def test_input_type(): self.assertRaises(TypeError, test_input_type) - def test_axis_type(): - with paddle.static.program_guard(paddle.static.Program(), - paddle.static.Program()): - data = paddle.static.data("data", shape=[10, 10], dtype="int64") - axis = paddle.static.data("axis", shape=[10, 10], dtype="int64") - result_min = paddle.min(data, axis) - - self.assertRaises(TypeError, test_axis_type) - def test_imperative_api(self): paddle.disable_static() np_x = np.array([10, 10]).astype('float64') @@ -103,5 +98,30 @@ def test_min(self): expect_dtypes=['float32', 'float64', 'int32', 'int64']) +class TestMinWithTensorAxis1(TestReduceOPTensorAxisBase): + + def init_data(self): + self.pd_api = paddle.min + self.np_api = np.min + self.x = paddle.randn([10, 5, 9, 9], dtype='float64') + self.np_axis = np.array([1, 2], dtype='int64') + self.tensor_axis = paddle.to_tensor([1, 2], dtype='int64') + + +class TestMinWithTensorAxis2(TestReduceOPTensorAxisBase): + + def init_data(self): + self.pd_api = paddle.min + self.np_api = np.min + self.x = paddle.randn([10, 10, 9, 9], dtype='float64') + self.np_axis = np.array([0, 1, 2], dtype='int64') + self.tensor_axis = [ + 0, + paddle.to_tensor([1], 'int64'), + paddle.to_tensor([2], 'int64') + ] + self.keepdim = True + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_prod_op.py b/python/paddle/fluid/tests/unittests/test_prod_op.py index 19611a933cc29..dde612ea746c5 100644 --- a/python/paddle/fluid/tests/unittests/test_prod_op.py +++ b/python/paddle/fluid/tests/unittests/test_prod_op.py @@ -17,6 +17,7 @@ import paddle import unittest import numpy as np +from test_sum_op import TestReduceOPTensorAxisBase class TestProdOp(unittest.TestCase): @@ -168,5 +169,29 @@ def test_error(self): self.assertRaises(TypeError, paddle.prod, x, 'bool') +class TestProdWithTensorAxis1(TestReduceOPTensorAxisBase): + + def init_data(self): + self.pd_api = paddle.prod + self.np_api = np.prod + self.x = paddle.randn([10, 5, 9, 9], dtype='float64') + self.np_axis = np.array([1, 2], dtype='int64') + self.tensor_axis = paddle.to_tensor([1, 2], dtype='int64') + + +class TestProdWithTensorAxis2(TestReduceOPTensorAxisBase): + + def init_data(self): + self.pd_api = paddle.prod + self.np_api = np.prod + self.x = paddle.randn([10, 10, 9, 9], dtype='float64') + self.np_axis = np.array([0, 1, 2], dtype='int64') + self.tensor_axis = [ + 0, + paddle.to_tensor([1], 'int64'), + paddle.to_tensor([2], 'int64') + ] + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index e327c335b0dd0..b143af2ac50c3 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -14,7 +14,9 @@ from __future__ import print_function +import os import unittest +import tempfile import numpy as np from op_test import OpTest import paddle @@ -27,6 +29,7 @@ convert_uint16_to_float) from paddle import _C_ops, _legacy_C_ops from paddle.fluid.framework import _test_eager_guard +import paddle.inference as paddle_infer class TestSumOp(OpTest): @@ -483,6 +486,100 @@ def test_list_of_none_input(): create_test_sum_fp16_class(TestSelectedRowsSumOp) create_test_sum_fp16_class(TestLoDTensorAndSelectedRowsOp) + +class TestReduceOPTensorAxisBase(unittest.TestCase): + + def setUp(self): + paddle.disable_static() + paddle.seed(2022) + self.temp_dir = tempfile.TemporaryDirectory() + self.save_path = os.path.join(self.temp_dir.name, 'reduce_tensor_axis') + self.place = paddle.CUDAPlace( + 0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace() + self.keepdim = False + self.init_data() + + def tearDwon(self): + self.temp_dir.cleanup() + + def init_data(self): + self.pd_api = paddle.sum + self.np_api = np.sum + self.x = paddle.randn([10, 5, 9, 9], dtype='float64') + self.np_axis = np.array((1, 2), dtype='int64') + self.tensor_axis = paddle.to_tensor(self.np_axis, dtype='int64') + + def test_dygraph(self): + self.x.stop_gradient = False + pd_out = self.pd_api(self.x, self.tensor_axis) + np_out = self.np_api(self.x.numpy(), tuple(self.np_axis)) + np.testing.assert_allclose( + pd_out.numpy() if pd_out.size > 1 else pd_out.item(), np_out) + pd_out.backward() + self.assertEqual(self.x.gradient().shape, tuple(self.x.shape)) + + def test_static_and_infer(self): + paddle.enable_static() + main_prog = paddle.static.Program() + starup_prog = paddle.static.Program() + with paddle.static.program_guard(main_prog, starup_prog): + # run static + x = paddle.static.data(shape=self.x.shape, + name='x', + dtype='float32') + if isinstance(self.tensor_axis, paddle.Tensor): + axis = paddle.assign(self.np_axis) + else: + axis = [] + for i, item in enumerate(self.tensor_axis): + if isinstance(item, int): + axis.append(item) + else: + axis.append(paddle.full([1], self.np_axis[i], 'int64')) + + linear = paddle.nn.Linear(x.shape[-1], 5) + linear_out = linear(x) + out = self.pd_api(linear_out, axis, keepdim=self.keepdim) + exe = paddle.static.Executor(self.place) + exe.run(starup_prog) + static_out = exe.run(feed={'x': self.x.numpy().astype('float32')}, + fetch_list=[out]) + + # run infer + paddle.static.save_inference_model(self.save_path, [x], [out], exe) + config = paddle_infer.Config(self.save_path + '.pdmodel', + self.save_path + '.pdiparams') + if paddle.is_compiled_with_cuda(): + config.enable_use_gpu(100, 0) + else: + config.disable_gpu() + predictor = paddle_infer.create_predictor(config) + input_names = predictor.get_input_names() + input_handle = predictor.get_input_handle(input_names[0]) + fake_input = self.x.numpy().astype('float32') + input_handle.reshape(self.x.shape) + input_handle.copy_from_cpu(fake_input) + predictor.run() + output_names = predictor.get_output_names() + output_handle = predictor.get_output_handle(output_names[0]) + infer_out = output_handle.copy_to_cpu() + np.testing.assert_allclose(static_out[0], infer_out) + + +class TestSumWithTensorAxis1(TestReduceOPTensorAxisBase): + + def init_data(self): + self.pd_api = paddle.sum + self.np_api = np.sum + self.x = paddle.randn([10, 5, 9, 9], dtype='float64') + self.np_axis = np.array([0, 1, 2], dtype='int64') + self.tensor_axis = [ + 0, + paddle.to_tensor([1], 'int64'), + paddle.to_tensor([2], 'int64') + ] + + if __name__ == "__main__": enable_static() unittest.main() diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 0bad545025c1c..3acd9d5897aa1 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -34,6 +34,7 @@ from ..framework import _varbase_creator, convert_np_dtype_to_dtype_ from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype from ..fluid.dygraph.inplace_utils import inplace_apis_in_dygraph_only +from ..fluid.layers import utils # TODO: define math functions # yapf: disable @@ -1144,11 +1145,22 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): out8 = paddle.sum(x, axis=0) # [1, 1, 1, 1] out9 = paddle.sum(x, axis=1) # [4, 0] """ - if axis is not None and not isinstance(axis, (list, tuple)): - axis = [axis] + if isinstance(axis, Variable): + reduce_all_flag = True if axis.shape[0] == len(x.shape) else False + else: + if axis is not None and not isinstance(axis, (list, tuple)): + axis = [axis] - if not axis: - axis = [] + if not axis: + axis = [] + + if len(axis) == 0: + reduce_all_flag = True + else: + if len(axis) == len(x.shape): + reduce_all_flag = True + else: + reduce_all_flag = False dtype_flag = False if dtype is not None: @@ -1158,16 +1170,12 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): if in_dygraph_mode(): return _C_ops.sum(x, axis, dtype, keepdim) - if len(axis) == 0: - reduce_all_flag = True - else: - if len(axis) == len(x.shape): - reduce_all_flag = True - else: - reduce_all_flag = False + if not isinstance(axis, Variable): + axis = axis if axis != None and axis != [] and axis != () else [0] + if utils._contain_var(axis): + axis = utils._convert_to_tensor_list(axis) if _in_legacy_dygraph(): - axis = axis if axis != None and axis != [] else [0] if dtype_flag: return _legacy_C_ops.reduce_sum(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all_flag, 'in_dtype', @@ -1177,7 +1185,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): 'reduce_all', reduce_all_flag) attrs = { - 'dim': axis if axis != None and axis != [] and axis != () else [0], + 'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all_flag } @@ -1194,7 +1202,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None): u'bool', u'float16', u'float32', u'float64', u'int32', u'int64', u'complex64', u'complex128'], 'sum') - check_type(axis, 'axis', (int, list, tuple, type(None)), 'sum') + check_type(axis, 'axis', (int, list, tuple, type(None), Variable), 'sum') helper = LayerHelper('sum', **locals()) if dtype_flag: @@ -2058,6 +2066,11 @@ def _get_reduce_axis(axis): axis = [] return reduce_all, axis +def _get_reduce_axis_with_tensor(axis): + if isinstance(axis, Variable): + return False, axis + return _get_reduce_axis(axis) + def _get_reduce_all_value(axis): """ Internal function for max, min, amax and amin. @@ -2154,7 +2167,7 @@ def max(x, axis=None, keepdim=False, name=None): #[7., 8.], [[[0., 0.], [0., 0.]], [[0., 0.], [1., 1.]]] """ - reduce_all, axis = _get_reduce_axis(axis) + reduce_all, axis = _get_reduce_axis_with_tensor(axis) if in_dygraph_mode(): return _C_ops.max(x, axis, keepdim) if _in_legacy_dygraph(): @@ -2164,6 +2177,8 @@ def max(x, axis=None, keepdim=False, name=None): helper = LayerHelper('max', **locals()) check_variable_and_dtype( x, 'x', ['float32', 'float64', 'int32', 'int64'], 'max') + if not isinstance(axis, Variable) and utils._contain_var(axis): + axis = utils._convert_to_tensor_list(axis) out = helper.create_variable_for_type_inference( dtype=x.dtype) @@ -2255,7 +2270,7 @@ def min(x, axis=None, keepdim=False, name=None): #[1., 2.], [[[1., 1.], [0., 0.]], [[0., 0.], [0., 0.]]] """ - reduce_all, axis = _get_reduce_axis(axis) + reduce_all, axis = _get_reduce_axis_with_tensor(axis) if in_dygraph_mode(): return _C_ops.min(x, axis, keepdim) @@ -2266,6 +2281,8 @@ def min(x, axis=None, keepdim=False, name=None): helper = LayerHelper('min', **locals()) check_variable_and_dtype( x, 'x', ['float32', 'float64', 'int32', 'int64'], 'min') + if not isinstance(axis, Variable) and utils._contain_var(axis): + axis = utils._convert_to_tensor_list(axis) out = helper.create_variable_for_type_inference( dtype=x.dtype) @@ -3369,19 +3386,22 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None): x = cast(x, dtype) dim = axis - if dim is not None and not isinstance(dim, list): - if isinstance(dim, tuple): - dim = list(dim) - elif isinstance(dim, int): - dim = [dim] - else: - raise TypeError( - "The type of axis must be int, list or tuple, but received {}". - format(type(dim))) + if isinstance(dim, Variable): + reduce_all = True if axis.shape[0] == len(x.shape) else False + else: + if dim is not None and not isinstance(dim, list): + if isinstance(dim, tuple): + dim = list(dim) + elif isinstance(dim, int): + dim = [dim] + else: + raise TypeError( + "The type of axis must be int, list or tuple, but received {}". + format(type(dim))) - reduce_all = True if dim is None or len(dim) == 0 or len(dim) == len(x.shape) else False - if dim is None or len(dim) == 0: - dim = [0] + reduce_all = True if dim is None or len(dim) == 0 or len(dim) == len(x.shape) else False + if dim is None or len(dim) == 0: + dim = [0] if in_dygraph_mode(): return _C_ops.reduce_prod(x, dim, keepdim, reduce_all) @@ -3393,6 +3413,8 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None): check_variable_and_dtype( x, 'x/input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod') out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) + if not isinstance(dim, Variable) and utils._contain_var(dim): + dim = utils._convert_to_tensor_list(dim) helper.append_op( type='reduce_prod', inputs={'X': x}, diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 043449cd6d81d..b3e14784c3d1c 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -21,6 +21,7 @@ from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode from .search import where from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype +from ..fluid.layers import utils import paddle from paddle import _C_ops, _legacy_C_ops @@ -80,17 +81,20 @@ def mean(x, axis=None, keepdim=False, name=None): # [ 8.5 12.5 16.5] """ - if isinstance(axis, int): - axis = [axis] - reduce_all = True if axis is None \ - or len(axis)==0 \ - or len(axis) == len(x.shape) else False - if axis is None or len(axis) == 0: - axis = [0] + if isinstance(axis, Variable): + reduce_all = True if axis.shape[0] == len(x.shape) else False + else: + if isinstance(axis, int): + axis = [axis] + reduce_all = True if axis is None \ + or len(axis)==0 \ + or len(axis) == len(x.shape) else False + if axis is None or len(axis) == 0: + axis = [0] if in_dygraph_mode(): if reduce_all: - axis = range(len(x.shape)) + axis = list(range(len(x.shape))) return _C_ops.mean(x, axis, keepdim) if _in_legacy_dygraph(): return _legacy_C_ops.reduce_mean(x, 'dim', axis, 'keep_dim', keepdim, @@ -99,12 +103,17 @@ def mean(x, axis=None, keepdim=False, name=None): check_variable_and_dtype(x, 'x/input', ['uint16', 'float16', 'float32', 'float64'], 'mean/reduce_mean') - check_type(axis, 'axis/dim', (int, list, tuple), 'mean/reduce_mean') + check_type(axis, 'axis/dim', (int, list, tuple, Variable), + 'mean/reduce_mean') if isinstance(axis, (list, tuple)): for item in axis: - check_type(item, 'elements of axis/dim', (int), 'mean/reduce_mean') + check_type(item, 'elements of axis/dim', (int, Variable), + 'mean/reduce_mean') helper = LayerHelper('mean', **locals()) + + if not isinstance(axis, Variable) and utils._contain_var(axis): + axis = utils._convert_to_tensor_list(axis) attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all} out = helper.create_variable_for_type_inference(x.dtype) helper.append_op(type='reduce_mean',