From e4e94a889a7e172ca92b9d0c4aca8c3c08a39fea Mon Sep 17 00:00:00 2001
From: Zhong Hui
Date: Wed, 1 Feb 2023 17:01:31 +0800
Subject: [PATCH] [Zero-Dim] Fix 0-dim tensor for arg_min_max op. (#49570)

* fix 0-d tensor for arg_min_max op.

* fix xpu.

* fix zero dims

* fix

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update test_zero_dim_tensor.py

* Update test_zero_dim_tensor_xpu.py

* Update test_zero_dim_tensor.py

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc
---
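Notes (kept below the "---" marker, so not part of the commit message): with
this change, paddle.argmax and paddle.argmin accept a 0-D input through the
flatten path and return a 0-D int64 index tensor. A minimal sketch of the
resulting dygraph behavior, assuming a Paddle build that already includes
this patch (the variable names are illustrative):

    import paddle

    x = paddle.rand([])      # 0-D tensor (a scalar)
    out = paddle.argmax(x)   # default axis=None takes the flatten path
    print(out.shape)         # []  -> the output is 0-D as well
    print(int(out))          # 0, the only valid index into a scalar

For a 0-D input, InferMeta only accepts axis 0 or -1 together with
flatten=true, and each kernel short-circuits to fill the single output
element with index 0.
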
 paddle/phi/infermeta/unary.cc                 | 47 +++++++++++++------
 paddle/phi/kernels/cpu/arg_min_max_kernel.cc  |  6 +++
 paddle/phi/kernels/gpu/arg_min_max_kernel.cu  |  7 +++
 paddle/phi/kernels/xpu/arg_min_max_kernel.cc  |  9 ++++
 .../tests/unittests/test_zero_dim_tensor.py   | 15 ++++--
 .../unittests/xpu/test_zero_dim_tensor_xpu.py |  5 +-
 6 files changed, 68 insertions(+), 21 deletions(-)

diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index eb05437ada8a5..2b35545db1cd8 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -160,22 +160,34 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
   auto int_axis = axis.to<int64_t>();
   const auto& x_dims = x.dims();
 
-  PADDLE_ENFORCE_GE(
-      int_axis,
-      -x_dims.size(),
-      phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
-                                   " -Rank(X)(%d).",
-                                   int_axis,
-                                   -x_dims.size()));
-  PADDLE_ENFORCE_LT(int_axis,
-                    x_dims.size(),
-                    phi::errors::InvalidArgument(
-                        "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
-                        int_axis,
-                        x_dims.size()));
+  auto x_rank = x.dims().size();
+  if (x_rank > 0) {
+    PADDLE_ENFORCE_GE(int_axis,
+                      -x_rank,
+                      phi::errors::InvalidArgument(
+                          "'axis'(%d) must be greater than or equal to"
+                          " -Rank(X)(%d).",
+                          int_axis,
+                          -x_rank));
+    PADDLE_ENFORCE_LT(
+        int_axis,
+        x_rank,
+        phi::errors::InvalidArgument(
+            "'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
+            int_axis,
+            x_rank));
+  } else {
+    // 0-dim tensor
+    PADDLE_ENFORCE_EQ((int_axis == 0 || int_axis == -1) && flatten,
+                      true,
+                      phi::errors::InvalidArgument(
+                          "'axis'(%d) must be 0 or -1 if input tensor is "
+                          "0-dim. and flatten should be true.",
+                          int_axis));
+  }
 
-  auto x_rank = x_dims.size();
   if (int_axis < 0) int_axis += x_rank;
+
   if (config.is_runtime) {
     if (dtype == phi::TransToProtoVarType(DataType::INT32)) {
       int64_t all_element_num = 0;
@@ -195,8 +207,12 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
                           INT_MAX));
     }
   }
+
   std::vector<int64_t> vec;
-  if (flatten) {
+
+  if (x_rank == 0) {
+    // vec is set to empty
+  } else if (flatten) {
     vec.emplace_back(static_cast<int64_t>(1));
   } else {
     for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]);
@@ -205,6 +221,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
     }
     for (int64_t i = int_axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
   }
+
   out->set_dims(phi::make_ddim(vec));
   if (dtype == 2) {
     out->set_dtype(DataType::INT32);
diff --git a/paddle/phi/kernels/cpu/arg_min_max_kernel.cc b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc
index 61d20ac32f15a..694698050a0c0 100644
--- a/paddle/phi/kernels/cpu/arg_min_max_kernel.cc
+++ b/paddle/phi/kernels/cpu/arg_min_max_kernel.cc
@@ -96,6 +96,12 @@ struct VisitDataArgMinMaxFunctor {
       if (axis < 0) new_axis = axis + x_dims.size();
     }
 
+    // For 0D Tensor
+    if (x.dims().size() == 0) {
+      phi::funcs::set_constant(dev_ctx, out, 0);
+      return;
+    }
+
 #define CALL_ARG_MINMAX_FUNCTOR(rank)                                    \
   ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
   functor##rank(dev_ctx, x, out, x_dims, new_axis, new_keepdims)
diff --git a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
index affd36a95ef8b..199ecc8e5b989 100644
--- a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
+++ b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
@@ -30,6 +30,7 @@ namespace cub = hipcub;
 
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 namespace {  // NOLINT
@@ -180,6 +181,12 @@ struct VisitDataCudaArgMinMaxFunctor {
       x_dims = x.dims();
       if (axis < 0) new_axis = axis + x.dims().size();
     }
+    // For 0D Tensor
+    if (x.dims().size() == 0) {
+      dev_ctx.template Alloc<IndType>(out);
+      phi::funcs::set_constant(dev_ctx, out, 0);
+      return;
+    }
 
     int64_t numel = x.numel();
     int64_t groups = numel / x_dims[new_axis];
diff --git a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc
index 3513b64bc600e..ebf13142345ce 100644
--- a/paddle/phi/kernels/xpu/arg_min_max_kernel.cc
+++ b/paddle/phi/kernels/xpu/arg_min_max_kernel.cc
@@ -18,6 +18,7 @@
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -39,7 +40,15 @@ void ArgMaxKernel(const Context& dev_ctx,
                       DataType::INT64,
                       DataType::INT32,
                       dtype));
+  // TODO(ZHUI): fix dtype of out
   dev_ctx.template Alloc<int64_t>(out);
+  if (x.dims().size() == 0) {
+    xpu::constant(dev_ctx.x_context(),
+                  out->data<int64_t>(),
+                  x.numel(),
+                  static_cast<int64_t>(0));
+    return;
+  }
 
   DDim x_dims;
   int axis_val = axis.to<int>();
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index 2d07ab31334df..fcc171674deab 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -189,6 +189,8 @@ def test_static_unary(self):
     paddle.logsumexp,
     paddle.all,
     paddle.any,
+    paddle.argmax,
+    paddle.argmin,
 ]
 
 
@@ -208,12 +210,13 @@ def test_dygraph_reduce(self):
             out.retain_grads()
             out.backward()
 
-            out_empty_list = api(x, [])
-            self.assertEqual(out_empty_list, out)
-
             self.assertEqual(x.shape, [])
             self.assertEqual(out.shape, [])
-            np.testing.assert_allclose(out.numpy(), x.numpy())
+            if api not in [paddle.argmax, paddle.argmin]:
+                np.testing.assert_allclose(out.numpy(), x.numpy())
+                out_empty_list = api(x, [])
+                self.assertEqual(out_empty_list, out)
+
             if x.grad is not None:
                 self.assertEqual(x.grad.shape, [])
                 self.assertEqual(out.grad.shape, [])
@@ -250,7 +253,9 @@ def test_static_reduce(self):
             res = exe.run(main_prog, fetch_list=fetch_list)
             self.assertEqual(res[0].shape, ())
             self.assertEqual(res[1].shape, ())
-            np.testing.assert_allclose(res[0], res[1])
+            if api not in [paddle.argmax, paddle.argmin]:
+                np.testing.assert_allclose(res[0], res[1])
+
             if len(res) > 2:
                 self.assertEqual(res[2].shape, ())
                 self.assertEqual(res[3].shape, ())
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
index f6f64aefe9db7..35e98e3cdaa75 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py
@@ -132,6 +132,8 @@ def test_dygraph_unary(self):
     paddle.logsumexp,
     paddle.all,
     paddle.any,
+    paddle.argmax,
+    paddle.argmin,
 ]
 
 
@@ -153,7 +155,8 @@ def test_dygraph_reduce(self):
 
         self.assertEqual(x.shape, [])
         self.assertEqual(out.shape, [])
-        np.testing.assert_allclose(out.numpy(), x.numpy())
+        if api not in [paddle.argmax, paddle.argmin]:
+            np.testing.assert_allclose(out.numpy(), x.numpy())
         if x.grad is not None:
             self.assertEqual(x.grad.shape, [])
             self.assertEqual(out.grad.shape, [])
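
As a companion to the test changes above, a condensed standalone sketch of
what the updated dygraph tests assert for the two newly covered APIs; the
rand([]) input and the explicit loop here are illustrative, not lifted from
the suite:

    import numpy as np
    import paddle

    x = paddle.rand([])  # 0-D input
    for api in [paddle.argmax, paddle.argmin]:
        out = api(x)
        assert out.shape == []                       # 0-D index output
        np.testing.assert_allclose(out.numpy(), 0)   # index 0 is the only option

The tests skip the value comparison against x.numpy() for argmax/argmin
because their output is an index into the flattened input, not a reduced
copy of the input value.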