Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Zero-Dim] Fix 0-dim tensor for arg_min_max op. #49570

Merged
merged 15 commits into the base branch from the contributor's branch
Feb 1, 2023
50 changes: 34 additions & 16 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,26 +160,39 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
auto int_axis = axis.to<int64_t>();
const auto& x_dims = x.dims();

PADDLE_ENFORCE_GE(
int_axis,
-x_dims.size(),
phi::errors::InvalidArgument("'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
int_axis,
-x_dims.size()));
PADDLE_ENFORCE_LT(int_axis,
x_dims.size(),
phi::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
int_axis,
x_dims.size()));
auto x_rank = x.dims().size();
if (x_rank > 0) {
PADDLE_ENFORCE_GE(int_axis,
-x_rank,
phi::errors::InvalidArgument(
"'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d).",
int_axis,
-x_rank));
PADDLE_ENFORCE_LT(
int_axis,
x_rank,
phi::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X).",
int_axis,
x_rank));
} else {
// 0-dim tensor
PADDLE_ENFORCE_EQ(
int_axis == 0 || int_axis == -1,
true,
phi::errors::InvalidArgument(
"'axis'(%d) must be 0 or -1 if input tensor is 0-dim.", int_axis));
}

auto x_rank = x_dims.size();
if (int_axis < 0) int_axis += x_rank;

if (config.is_runtime) {
if (dtype == phi::TransToProtoVarType(DataType::INT32)) {
int64_t all_element_num = 0;
if (flatten) {
if (x_rank == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个也不用写分支,因为phi::product(x_dims); 里面已经支持了0D的product计算结果是1
(Translation: This branch is also unnecessary, because phi::product(x_dims) already supports 0-D tensors — its product result is 1.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

额，这里有个flatten的配置，怕这个 为 false的话，后面可能有问题。
(Translation: Hmm, there is a `flatten` option here; I'm worried that if it is false, there could be problems later on.)

Copy link
Contributor

@zhwesky2010 zhwesky2010 Jan 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0D的时候axis只能为None，就是flatten的情况
(Translation: For a 0-D tensor, `axis` can only be None, which is exactly the flatten case.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

all_element_num = 1;
} else if (flatten) {
all_element_num = phi::product(x_dims);
} else {
all_element_num = x_dims[int_axis];
Expand All @@ -195,8 +208,12 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
INT_MAX));
}
}

std::vector<int64_t> vec;
if (flatten) {

if (x_rank == 0) {
// vec is set to empty
} else if (flatten) {
vec.emplace_back(static_cast<int64_t>(1));
} else {
for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]);
Expand All @@ -205,6 +222,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
}
for (int64_t i = int_axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
}

out->set_dims(phi::make_ddim(vec));
if (dtype == 2) {
out->set_dtype(DataType::INT32);
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/kernels/cpu/arg_min_max_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ struct VisitDataArgMinMaxFunctor {
if (axis < 0) new_axis = axis + x_dims.size();
}

// For 0D Tensor
if (x.dims().size() == 0) {
phi::funcs::set_constant(dev_ctx, out, 0);
return;
}

#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
functor##rank(dev_ctx, x, out, x_dims, new_axis, new_keepdims)
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/kernels/gpu/arg_min_max_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace cub = hipcub;

#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {

namespace { // NOLINT
Expand Down Expand Up @@ -180,6 +181,12 @@ struct VisitDataCudaArgMinMaxFunctor {
x_dims = x.dims();
if (axis < 0) new_axis = axis + x.dims().size();
}
// For 0D Tensor
if (x.dims().size() == 0) {
dev_ctx.template Alloc<IndType>(out);
phi::funcs::set_constant(dev_ctx, out, 0);
return;
}

int64_t numel = x.numel();
int64_t groups = numel / x_dims[new_axis];
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/kernels/xpu/arg_min_max_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

Expand All @@ -39,7 +40,12 @@ void ArgMaxKernel(const Context& dev_ctx,
DataType::INT64,
DataType::INT32,
dtype));
// TODO(ZHUI): fix dtype of out
dev_ctx.template Alloc<int64_t>(out);
if (x.dims().size() == 0) {
phi::funcs::set_constant(dev_ctx, out, 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

XPU的设置为常数：
(Translation: On XPU, setting to a constant is done via:)

xpu::constant<T>(
        dev_ctx.x_context(), dx_data, x->numel(), static_cast<T>(1.0));

return;
}

DDim x_dims;
int axis_val = axis.to<int>();
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
paddle.nn.functional.silu,
paddle.nn.functional.tanh,
paddle.nn.functional.dropout,
paddle.argmax,
ZHUI marked this conversation as resolved.
Show resolved Hide resolved
paddle.argmin,
paddle.cosh,
paddle.sinh,
paddle.abs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
paddle.sinh,
paddle.abs,
paddle.acos,
paddle.argmax,
paddle.argmin,
paddle.asin,
paddle.atan,
paddle.ceil,
Expand Down