[Zero-Dim] Support output 0D for argmin/argmax/median/kthvalue/mode/equal_all/allclose, test=allcase #51889

Merged
merged 3 commits on Mar 21, 2023
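This PR makes the listed reduction-style ops return 0-D (zero-dimensional) tensors when they collapse to a single value, instead of the previous shape-[1] tensors. Below is a minimal dygraph sketch of the intended user-visible behavior; the op names come from the PR title and the shape comments are assumptions based on the diff, not output copied from a run.

```python
import paddle

x = paddle.rand([3, 4])
y = x.clone()

print(paddle.equal_all(x, y).shape)   # []  (0-D), previously [1]
print(paddle.allclose(x, y).shape)    # []
print(paddle.argmax(x).shape)         # []  when reducing over all elements
print(paddle.median(x).shape)         # []
v, idx = paddle.mode(paddle.rand([5]))
print(v.shape, idx.shape)             # []  []
```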
16 changes: 2 additions & 14 deletions paddle/phi/infermeta/binary.cc
@@ -85,14 +85,7 @@ void AllValueCompareInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config) {
detail::BinarySameInputDimsCheck(x, y, config);

auto x_dims = x.dims();
auto y_dims = y.dims();
if (x_dims.size() == 0 && y_dims.size() == 0) {
out->set_dims(phi::make_ddim({}));
} else {
out->set_dims(phi::make_ddim({1}));
}
out->set_dims(phi::make_ddim({}));
out->set_dtype(DataType::BOOL);
}

@@ -403,12 +396,7 @@ void CompareAllInferMeta(const MetaTensor& x,
errors::InvalidArgument(
"The size of dim_y should not be greater than dim_x's."));
out->share_lod(x);
if (!x.dims().size() || !y.dims().size()) {
out->set_dims(make_ddim({}));
} else {
out->set_dims(make_ddim({1}));
}
out->set_dtype(DataType::BOOL);
out->set_dims(make_ddim({}));
}

void ComplexInferMeta(const MetaTensor& x,
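Both hunks above now infer an empty dim vector for the boolean result unconditionally, so the statically inferred output shape of the all-elements comparison ops becomes () for any input rank. A hedged static-graph sketch of that effect follows; which infermeta function backs which Python API is not shown in this diff, and the printed shapes are expectations rather than captured output.

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    a = paddle.static.data(name='a', shape=[2, 3], dtype='float32')
    b = paddle.static.data(name='b', shape=[2, 3], dtype='float32')
    eq = paddle.equal_all(a, b)      # reduces every element to one bool
    close = paddle.allclose(a, b)
    print(eq.shape, close.shape)     # expected: () ()  -- was (1,) (1,) before
```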
42 changes: 27 additions & 15 deletions paddle/phi/infermeta/unary.cc
@@ -148,7 +148,11 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
if (!config.is_runtime && axis.FromTensor()) {
std::vector<int64_t> vec;
if (flatten) {
vec = {1};
if (keepdims) {
vec = std::vector<int64_t>(x.dims().size(), -1);
} else {
vec = {};
}
} else {
if (keepdims) {
vec = std::vector<int64_t>(x.dims().size(), -1);
@@ -169,7 +173,6 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
const auto& x_dims = x.dims();

auto x_rank = x.dims().size();
auto zero_dim_tensor = x_rank == 0;
if (x_rank > 0) {
PADDLE_ENFORCE_GE(int_axis,
-x_rank,
@@ -200,7 +203,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
if (config.is_runtime) {
if (dtype == phi::TransToProtoVarType(DataType::INT32)) {
int64_t all_element_num = 0;
if (flatten || zero_dim_tensor) {
if (flatten) {
all_element_num = phi::product(x_dims);
} else {
all_element_num = x_dims[int_axis];
@@ -218,11 +221,12 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
}

std::vector<int64_t> vec;

if (x_rank == 0) {
// vec is set to empty
} else if (flatten) {
vec.emplace_back(static_cast<int64_t>(1));
if (flatten) {
if (keepdims) {
vec = std::vector<int64_t>(x.dims().size(), 1);
} else {
vec = {};
}
} else {
for (int64_t i = 0; i < int_axis; i++) vec.emplace_back(x_dims[i]);
if (keepdims) {
@@ -1838,21 +1842,29 @@ void KthvalueInferMeta(const MetaTensor& x,
MetaConfig config) {
auto input_dims = x.dims();
const int& dim_size = input_dims.size();
PADDLE_ENFORCE_LE(axis,
dim_size,
phi::errors::InvalidArgument(
"the axis must be [-%d, %d), but received %d .",
dim_size,
dim_size,
axis));
if (dim_size > 0) {
PADDLE_ENFORCE_LT(axis,
dim_size,
phi::errors::InvalidArgument(
"the axis must be [-%d, %d), but received %d .",
dim_size,
dim_size,
axis));
PADDLE_ENFORCE_GE(axis,
-dim_size,
phi::errors::InvalidArgument(
"the axis must be [-%d, %d), but received %d .",
dim_size,
dim_size,
axis));
} else if (dim_size == 0) {
// 0-dim tensor
PADDLE_ENFORCE_EQ(axis == 0 || axis == -1,
true,
phi::errors::InvalidArgument(
"'axis'(%d) must be 0 or -1 if input tensor is "
"0-dim.",
axis));
}
if (axis < 0) axis += dim_size;
PADDLE_ENFORCE_GE(
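In ArgMinMaxInferMeta, a flattened reduction now yields a 0-D output (or an all-ones shape of the input rank when keepdims is set), and KthvalueInferMeta accepts only axis 0 or -1 for 0-D inputs. A hedged dygraph sketch of the resulting shapes; these are assumptions derived from the hunks above, not the PR's own tests.

```python
import paddle

x = paddle.rand([2, 3])
print(paddle.argmax(x).shape)                        # []      flatten -> 0-D
print(paddle.argmax(x, keepdim=True).shape)          # [1, 1]  flatten + keepdims
print(paddle.argmax(x, axis=1).shape)                # [2]
print(paddle.argmax(x, axis=1, keepdim=True).shape)  # [2, 1]

x0 = paddle.to_tensor(3.0)          # 0-D input: the only valid index is 0
print(paddle.argmax(x0).shape)      # []
v, idx = paddle.kthvalue(x0, k=1)   # axis must be 0 or -1 for a 0-D input
print(v.shape, idx.shape)           # [] []
```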
74 changes: 38 additions & 36 deletions paddle/phi/kernels/cpu/arg_min_max_kernel.cc
@@ -32,26 +32,34 @@ template <typename Context,
ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};

#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename Context, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<Context, T, Tout, Rank, enum_argminmax_value> { \
void operator()(const Context& dev_ctx, \
const DenseTensor& in, \
DenseTensor* out, \
phi::DDim x_dims, \
int64_t axis, \
bool keepdims) { \
auto in_eigen = EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
auto out_eigen = EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename Context, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<Context, T, Tout, Rank, enum_argminmax_value> { \
void operator()(const Context& dev_ctx, \
const DenseTensor& in, \
DenseTensor* out, \
phi::DDim x_dims, \
phi::DDim out_dims, \
int64_t axis, \
bool keepdims, \
bool flatten) { \
auto in_eigen = EigenTensor<T, Rank>::From(in, x_dims); \
if (flatten) { \
auto out_eigen = EigenTensor<Tout, 0>::From(*out, out_dims); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
if (keepdims) { \
auto out_eigen = EigenTensor<Tout, Rank>::From(*out, out_dims); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = EigenTensor<Tout, Rank - 1>::From(*out, out_dims); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
} \
}

DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
@@ -81,32 +89,30 @@ struct VisitDataArgMinMaxFunctor {
template <typename Tout>
void apply() const {
dev_ctx.template Alloc<Tout>(out);
bool new_keepdims = keepdims;
if (flatten) new_keepdims = true;

// if flatten, construct the new dims for the calculation
phi::DDim x_dims;
phi::DDim out_dims;
int new_axis = axis;
if (flatten) {
// always reduce 1D -> 0D
x_dims = phi::make_ddim({x.numel()});
// if flatten, the axis is just 0
out_dims = phi::make_ddim({});
new_axis = 0;
} else {
x_dims = x.dims();
out_dims = out->dims();
if (axis < 0) new_axis = axis + x_dims.size();
}

// For 0D Tensor
if (x.dims().size() == 0) {
phi::funcs::set_constant(dev_ctx, out, 0);
return;
}

#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
functor##rank(dev_ctx, x, out, x_dims, new_axis, new_keepdims)
functor##rank(dev_ctx, x, out, x_dims, out_dims, new_axis, keepdims, flatten)

switch (x_dims.size()) {
case 0:
phi::funcs::set_constant(dev_ctx, out, 0);
return;
case 1:
CALL_ARG_MINMAX_FUNCTOR(1);
break;
@@ -195,9 +201,7 @@ PD_REGISTER_KERNEL(argmin,
int32_t,
int64_t,
int16_t,
uint8_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
uint8_t) {}

PD_REGISTER_KERNEL(argmax,
CPU,
@@ -208,6 +212,4 @@ PD_REGISTER_KERNEL(argmax,
int32_t,
int64_t,
int16_t,
uint8_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
uint8_t) {}
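The registrations above no longer force the output dtype to UNDEFINED, so the index dtype is whatever the op's `dtype` attribute resolves to in infermeta. A hedged usage sketch; the int64 default is an assumption about the Python API rather than something shown in this diff.

```python
import paddle

x = paddle.rand([2, 3])
print(paddle.argmax(x).dtype)                 # paddle.int64 (default)
print(paddle.argmin(x, dtype='int32').dtype)  # paddle.int32
```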
6 changes: 4 additions & 2 deletions python/paddle/fluid/dygraph/learning_rate_scheduler.py
@@ -95,8 +95,10 @@ def state_dict(self):
if isinstance(value, Variable):
assert (
value.size == 1
), "size of Variable in state_dict must be 1"
value = float(value)
), "the size of Variable in state_dict must be 1, but its size is {} with shape {}".format(
value.size, value.shape
)
value = value.item()
state_dict[key] = value

return state_dict
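`float(value)` becomes `value.item()` so the conversion works the same way for 0-D and shape-[1] tensors and does not coerce integer-valued state to float. A minimal sketch of that assumption:

```python
import paddle

step = paddle.to_tensor(10)   # 0-D int64 tensor, e.g. a scheduler step counter
assert step.size == 1         # the state_dict assertion above still holds
value = step.item()           # 10 as a Python int; float(step) would force a float
print(type(value), value)     # <class 'int'> 10
```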
@@ -52,7 +52,7 @@ def __init__(self, train_id):

def forward(self, x):
is_use = (
paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).numpy()[0]
paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).item()
and self.trainer_id == 1
)

@@ -52,7 +52,7 @@ def __init__(self, train_id):

def forward(self, x):
is_use = (
paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).numpy()[0]
paddle.equal_all(x, paddle.ones(shape=(batch, in_dim))).item()
and self.trainer_id == 1
)

32 changes: 17 additions & 15 deletions python/paddle/fluid/tests/unittests/test_allclose_layer.py
@@ -19,6 +19,8 @@
import paddle
import paddle.fluid as fluid

paddle.enable_static()


class TestAllcloseLayer(unittest.TestCase):
def allclose_check(self, use_cuda, dtype='float32'):
@@ -44,31 +46,31 @@ def allclose_check(self, use_cuda, dtype='float32'):
result_v, result_nan_v = exe.run(
feed={'a': x, 'b': y}, fetch_list=[result, result_nan]
)
self.assertEqual(result_v[0], False)
self.assertEqual(result_nan_v[0], False)
self.assertEqual(result_v, False)
self.assertEqual(result_nan_v, False)

x = np.array([10000.0, 1e-08]).astype(dtype)
y = np.array([10000.1, 1e-09]).astype(dtype)
result_v, result_nan_v = exe.run(
feed={'a': x, 'b': y}, fetch_list=[result, result_nan]
)
self.assertEqual(result_v[0], True)
self.assertEqual(result_nan_v[0], True)
self.assertEqual(result_v, True)
self.assertEqual(result_nan_v, True)

x = np.array([1.0, float('nan')]).astype(dtype)
y = np.array([1.0, float('nan')]).astype(dtype)
result_v, result_nan_v = exe.run(
feed={'a': x, 'b': y}, fetch_list=[result, result_nan]
)
self.assertEqual(result_v[0], False)
self.assertEqual(result_nan_v[0], True)
self.assertEqual(result_v, False)
self.assertEqual(result_nan_v, True)

# for corner case
x = np.array([10.1, 10.1]).astype(dtype)
y = np.array([10, 10]).astype(dtype)
(result_c,) = exe.run(feed={'a': x, 'b': y}, fetch_list=[result_corner])
corner_res = dtype == 'float64'
self.assertEqual(result_c[0], corner_res)
self.assertEqual(result_c, corner_res)

def test_allclose_cpu_fp32(self):
main = fluid.Program()
@@ -123,7 +125,7 @@ def test_dygraph_mode(self):
equal_nan=False,
name='test_1',
)
self.assertEqual(ret_1.numpy()[0], False)
self.assertEqual(ret_1.numpy(), False)
ret_1 = paddle.allclose(
x_v_1,
y_v_1,
Expand All @@ -132,7 +134,7 @@ def test_dygraph_mode(self):
equal_nan=True,
name='test_2',
)
self.assertEqual(ret_1.numpy()[0], False)
self.assertEqual(ret_1.numpy(), False)
x_v_2 = paddle.to_tensor(x_2)
y_v_2 = paddle.to_tensor(y_2)
ret_2 = paddle.allclose(
Expand All @@ -143,7 +145,7 @@ def test_dygraph_mode(self):
equal_nan=False,
name='test_3',
)
self.assertEqual(ret_2.numpy()[0], True)
self.assertEqual(ret_2.numpy(), True)
ret_2 = paddle.allclose(
x_v_2,
y_v_2,
Expand All @@ -152,7 +154,7 @@ def test_dygraph_mode(self):
equal_nan=True,
name='test_4',
)
self.assertEqual(ret_2.numpy()[0], True)
self.assertEqual(ret_2.numpy(), True)
x_v_3 = paddle.to_tensor(x_3)
y_v_3 = paddle.to_tensor(y_3)
ret_3 = paddle.allclose(
Expand All @@ -163,7 +165,7 @@ def test_dygraph_mode(self):
equal_nan=False,
name='test_5',
)
self.assertEqual(ret_3.numpy()[0], False)
self.assertEqual(ret_3.numpy(), False)
ret_3 = paddle.allclose(
x_v_3,
y_v_3,
Expand All @@ -172,20 +174,20 @@ def test_dygraph_mode(self):
equal_nan=True,
name='test_6',
)
self.assertEqual(ret_3.numpy()[0], True)
self.assertEqual(ret_3.numpy(), True)
# for corner case
x_v_4 = paddle.to_tensor(x_4)
y_v_4 = paddle.to_tensor(y_4)
ret_4 = paddle.allclose(
x_v_4, y_v_4, rtol=0.01, atol=0.0, name='test_7'
)
self.assertEqual(ret_4.numpy()[0], False)
self.assertEqual(ret_4.numpy(), False)
x_v_5 = paddle.to_tensor(x_5)
y_v_5 = paddle.to_tensor(y_5)
ret_5 = paddle.allclose(
x_v_5, y_v_5, rtol=0.015, atol=0.0, name='test_8'
)
self.assertEqual(ret_5.numpy()[0], True)
self.assertEqual(ret_5.numpy(), True)


if __name__ == "__main__":
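The trailing `[0]` is dropped throughout because allclose now returns a 0-D result: the fetched or `.numpy()` value is a 0-D array that compares directly against a Python bool. A hedged illustration of that reasoning (not part of the test file):

```python
import numpy as np
import paddle

r = paddle.allclose(paddle.to_tensor([1.0, 2.0]), paddle.to_tensor([1.0, 2.0]))
print(r.shape)                       # []  -- indexing r[0] is no longer meaningful
print(r.numpy().shape)               # ()
assert bool(r)                       # or r.item()
assert r.numpy() == np.array(True)   # 0-D ndarray compares against a scalar
```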
6 changes: 1 addition & 5 deletions python/paddle/fluid/tests/unittests/test_median.py
@@ -32,8 +32,6 @@ def static_single_test_median(self, lis_test):
paddle.enable_static()
x, axis, keepdims = lis_test
res_np = np.median(x, axis=axis, keepdims=keepdims)
if not isinstance(res_np, np.ndarray):
res_np = np.array([res_np])
main_program = Program()
startup_program = Program()
exe = paddle.static.Executor()
@@ -47,10 +45,8 @@ def static_single_test_median(self, lis_test):
def dygraph_single_test_median(self, lis_test):
x, axis, keepdims = lis_test
res_np = np.median(x, axis=axis, keepdims=keepdims)
if not isinstance(res_np, np.ndarray):
res_np = np.array([res_np])
res_pd = paddle.median(paddle.to_tensor(x), axis, keepdims)
self.check_numpy_res(res_pd.numpy(), res_np)
self.check_numpy_res(res_pd.numpy(False), res_np)

def test_median_static(self):
h = 3
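With a 0-D result from paddle.median there is no need to wrap the NumPy scalar from np.median into a shape-[1] array before comparing. A hedged sketch of the direct comparison the updated test relies on (illustrative, not copied from the test):

```python
import numpy as np
import paddle

x = np.arange(12, dtype='float32').reshape(3, 4)
res_np = np.median(x)                        # NumPy scalar, no np.array([...]) wrapper
res_pd = paddle.median(paddle.to_tensor(x))  # 0-D tensor after this change
np.testing.assert_allclose(res_pd.numpy(), res_np, rtol=1e-6)
```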