Skip to content

Commit

Permalink
[Zero-Dim] Fix 0-dim tensor for arg_min_max op. (#49570)
Browse files Browse the repository at this point in the history
* fix 0-d tensor for arg_min_max op.

* fix xpu.

* fix zero dims

* fix

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update test_zero_dim_tensor.py

* Update test_zero_dim_tensor_xpu.py

* Update test_zero_dim_tensor.py

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc

* Update arg_min_max_kernel.cc
  • Loading branch information
ZHUI committed Feb 1, 2023
1 parent 71f247b commit e4e94a8
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 21 deletions.
47 changes: 32 additions & 15 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,22 +160,34 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
auto int_axis = axis.to<int64_t>();
const auto& x_dims = x.dims();

PADDLE_ENFORCE_GE(
int_axis,
-x_dims.size(),
phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
int_axis,
-x_dims.size()));
PADDLE_ENFORCE_LT(int_axis,
x_dims.size(),
phi::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
int_axis,
x_dims.size()));
auto x_rank = x.dims().size();
if (x_rank > 0) {
PADDLE_ENFORCE_GE(int_axis,
-x_rank,
phi::errors::InvalidArgument(
"'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
int_axis,
-x_rank));
PADDLE_ENFORCE_LT(
int_axis,
x_rank,
phi::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
int_axis,
x_rank));
} else {
// 0-dim tensor
PADDLE_ENFORCE_EQ((int_axis == 0 || int_axis == -1) && flatten,
true,
phi::errors::InvalidArgument(
"'axis'(%d) must be 0 or -1 if input tensor is "
"0-dim. and flatten should be true.",
int_axis));
}

auto x_rank = x_dims.size();
if (int_axis < 0) int_axis += x_rank;

if (config.is_runtime) {
if (dtype == phi::TransToProtoVarType(DataType::INT32)) {
int64_t all_element_num = 0;
Expand All @@ -195,8 +207,12 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
INT_MAX));
}
}

std::vector<int64_t> vec;
if (flatten) {

if (x_rank == 0) {
// vec is set to empty
} else if (flatten) {
vec.emplace_back(static_cast<int64_t>(1));
} else {
for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]);
Expand All @@ -205,6 +221,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
}
for (int64_t i = int_axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
}

out->set_dims(phi::make_ddim(vec));
if (dtype == 2) {
out->set_dtype(DataType::INT32);
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/kernels/cpu/arg_min_max_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ struct VisitDataArgMinMaxFunctor {
if (axis < 0) new_axis = axis + x_dims.size();
}

// For 0D Tensor
if (x.dims().size() == 0) {
phi::funcs::set_constant(dev_ctx, out, 0);
return;
}

#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
functor##rank(dev_ctx, x, out, x_dims, new_axis, new_keepdims)
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/kernels/gpu/arg_min_max_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace cub = hipcub;

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {

namespace { // NOLINT
Expand Down Expand Up @@ -180,6 +181,12 @@ struct VisitDataCudaArgMinMaxFunctor {
x_dims = x.dims();
if (axis < 0) new_axis = axis + x.dims().size();
}
// For 0D Tensor
if (x.dims().size() == 0) {
dev_ctx.template Alloc<IndType>(out);
phi::funcs::set_constant(dev_ctx, out, 0);
return;
}

int64_t numel = x.numel();
int64_t groups = numel / x_dims[new_axis];
Expand Down
9 changes: 9 additions & 0 deletions paddle/phi/kernels/xpu/arg_min_max_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

Expand All @@ -39,7 +40,15 @@ void ArgMaxKernel(const Context& dev_ctx,
DataType::INT64,
DataType::INT32,
dtype));
// TODO(ZHUI): fix dtype of out
dev_ctx.template Alloc<int64_t>(out);
if (x.dims().size() == 0) {
xpu::constant(dev_ctx.x_context(),
out->data<int64_t>(),
x.numel(),
static_cast<int64_t>(0));
return;
}

DDim x_dims;
int axis_val = axis.to<int>();
Expand Down
15 changes: 10 additions & 5 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ def test_static_unary(self):
paddle.logsumexp,
paddle.all,
paddle.any,
paddle.argmax,
paddle.argmin,
]


Expand All @@ -208,12 +210,13 @@ def test_dygraph_reduce(self):
out.retain_grads()
out.backward()

out_empty_list = api(x, [])
self.assertEqual(out_empty_list, out)

self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
np.testing.assert_allclose(out.numpy(), x.numpy())
if api not in [paddle.argmax, paddle.argmin]:
np.testing.assert_allclose(out.numpy(), x.numpy())
out_empty_list = api(x, [])
self.assertEqual(out_empty_list, out)

if x.grad is not None:
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [])
Expand Down Expand Up @@ -250,7 +253,9 @@ def test_static_reduce(self):
res = exe.run(main_prog, fetch_list=fetch_list)
self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, ())
np.testing.assert_allclose(res[0], res[1])
if api not in [paddle.argmax, paddle.argmin]:
np.testing.assert_allclose(res[0], res[1])

if len(res) > 2:
self.assertEqual(res[2].shape, ())
self.assertEqual(res[3].shape, ())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def test_dygraph_unary(self):
paddle.logsumexp,
paddle.all,
paddle.any,
paddle.argmax,
paddle.argmin,
]


Expand All @@ -153,7 +155,8 @@ def test_dygraph_reduce(self):

self.assertEqual(x.shape, [])
self.assertEqual(out.shape, [])
np.testing.assert_allclose(out.numpy(), x.numpy())
if api not in [paddle.argmax, paddle.argmin]:
np.testing.assert_allclose(out.numpy(), x.numpy())
if x.grad is not None:
self.assertEqual(x.grad.shape, [])
self.assertEqual(out.grad.shape, [])
Expand Down

0 comments on commit e4e94a8

Please sign in to comment.