[Zero-Dim] Support paddle.all/any/min/prod/logsumexp/amax/amin/quantile/some loss output 0D, test=allcase #53051

Merged 1 commit on Apr 20, 2023
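
As a quick orientation before the diff, here is a minimal sketch of the user-visible behavior the title describes. It is illustrative only (the tensors and printed shapes are my assumption of the intended result, not taken from the PR's tests): a reduction over all elements by the listed APIs should now report a 0-D shape (`[]`) rather than `[1]`.

```python
import paddle

# Illustrative sketch, assuming eager mode and the post-PR behavior:
x = paddle.rand([2, 3])
print(paddle.min(x).shape)          # expected after this PR: []  (0-D), was [1]
print(paddle.all(x > 0.5).shape)    # expected after this PR: []
print(paddle.min(x, axis=0).shape)  # partial reduction is unchanged: [3]
```
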
4 changes: 2 additions & 2 deletions paddle/fluid/operators/assert_op.cc
@@ -58,8 +58,8 @@ class AssertOp : public framework::OperatorBase {
"Input(Condition) of AssertOp is not found."));
const phi::DenseTensor &cond = cond_var_ptr->Get<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
cond.dims(),
phi::make_ddim({1}),
cond.numel(),
1,
platform::errors::InvalidArgument(
"The numel of Input(Condition) of AssertOp must be 1. But now "
"the Condition's shape is %s.",
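
The check above is relaxed from an exact shape of {1} to numel == 1, presumably so that a 0-D condition (for example the result of a full reduction) is also accepted. A small sketch of that reading, with illustrative values:

```python
import paddle

# Both forms of the condition hold exactly one element, so a numel-based check
# accepts either of them (mirrors the relaxed C++ check above).
cond_0d = paddle.all(paddle.to_tensor([True, True]))  # shape [] after this PR
cond_1d = paddle.to_tensor([True])                    # shape [1]
for cond in (cond_0d, cond_1d):
    assert int(paddle.numel(cond)) == 1
```
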
7 changes: 3 additions & 4 deletions paddle/fluid/operators/reduce_ops/reduce_max_op.cc
@@ -54,10 +54,9 @@ class ReduceMaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(
reduce_max,
ReduceMaxInferShapeFunctor,
PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
DECLARE_INFER_SHAPE_FUNCTOR(reduce_max,
ReduceMaxInferShapeFunctor,
PD_INFER_META(phi::OriginReduceInferMetaBase));

REGISTER_OPERATOR(
reduce_max,
7 changes: 3 additions & 4 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.cc
@@ -98,10 +98,9 @@ class __reduce_meanMaker__ : public ops::ReduceBaseOpMaker {
virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
};

DECLARE_INFER_SHAPE_FUNCTOR(
reduce_mean,
ReduceMeanInferShapeFunctor,
PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
DECLARE_INFER_SHAPE_FUNCTOR(reduce_mean,
ReduceMeanInferShapeFunctor,
PD_INFER_META(phi::OriginReduceInferMetaBase));

REGISTER_OPERATOR(reduce_mean,
ops::ReduceBaseOp,
@@ -1132,7 +1132,7 @@ void prod_grad(const Tensor& x,
if (!keep_dim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 1; i < x_dim_size; i++) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
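
The loop above now starts at 0 instead of 1. A hedged NumPy analogue of what that axis list is used for in the reduce_all, keep_dim=false case (names and shapes are illustrative, not Paddle code): every reduced axis, including axis 0, must be re-inserted so the upstream gradient can broadcast back over x.

```python
import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]])
out_grad = np.float64(1.0)                 # 0-D upstream gradient
axes = tuple(range(x.ndim))                # (0, 1): starts at 0, not 1
restored = np.expand_dims(out_grad, axes)  # shape (1, 1), broadcastable over x
x_grad = restored * np.prod(x) / x         # d(prod(x))/dx_i = prod(x) / x_i
print(x_grad.shape)                        # (2, 2)
```
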
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -784,7 +784,7 @@
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
infer_meta :
func : ReduceIntArrayAxisInferMeta
func : OriginReduceInferMeta
kernel :
func : max
backward : max_grad
@@ -820,7 +820,7 @@
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
infer_meta :
func : ReduceIntArrayAxisInferMeta
func : OriginReduceInferMeta
kernel :
func : mean
backward : mean_grad
197 changes: 160 additions & 37 deletions paddle/phi/infermeta/unary.cc
@@ -2146,7 +2146,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
}

void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) {
out->set_dims(phi::make_ddim({1}));
out->set_dims(phi::make_ddim({}));
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
@@ -3050,29 +3050,19 @@ DDim ReduceInferDim(const MetaTensor& x,
reduce_all = reduce_all || full_dim;

std::vector<int64_t> out_dim_vector;
if (keep_dim) {
for (int64_t i = 0; i < x_rank; ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
for (int64_t i = 0; i < x_rank; ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
if (keep_dim) {
out_dim_vector.push_back(1);
} else {
out_dim_vector.push_back(x.dims().at(i));
}
}
} else {
for (int64_t i = 0; i < x_rank; ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
continue;
} else {
out_dim_vector.push_back(x.dims().at(i));
}
}

if (x_rank > 0 && out_dim_vector.size() == 0) {
out_dim_vector.push_back(1);
} else {
out_dim_vector.push_back(x.dims().at(i));
}
}
DDim out_dim = phi::make_ddim(out_dim_vector);

DDim out_dim = phi::make_ddim(out_dim_vector);
return out_dim;
}
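
A hedged Python restatement of the restructured shape rule in ReduceInferDim above (the standalone function and its name are mine, for illustration): a reduced axis is kept as 1 only under keep_dim, otherwise it is simply dropped, so a full reduction of a non-0-D input now yields an empty shape instead of being padded back to [1].

```python
def reduce_infer_shape(x_shape, axes, keep_dim, reduce_all):
    # Hypothetical mirror of the new ReduceInferDim control flow.
    rank = len(x_shape)
    dims_set = {a + rank if a < 0 else a for a in axes}
    reduce_all = reduce_all or set(range(rank)).issubset(dims_set)
    out = []
    for i in range(rank):
        if reduce_all or i in dims_set:
            if keep_dim:
                out.append(1)
            # else: axis is dropped; no padding back to [1] any more
        else:
            out.append(x_shape[i])
    return out

print(reduce_infer_shape([2, 3], [0, 1], keep_dim=False, reduce_all=False))  # []
print(reduce_infer_shape([2, 3], [0], keep_dim=True, reduce_all=False))      # [1, 3]
print(reduce_infer_shape([2, 3], [1], keep_dim=False, reduce_all=False))     # [2]
```
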

@@ -3086,14 +3076,14 @@ DDim ReduceInferDimForIntArrayAxis(const MetaTensor& x,
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), 1);
} else {
vec_dim = {1};
vec_dim = {};
}
} else {
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), -1);
} else {
auto x_rank = static_cast<size_t>(x.dims().size());
if (vec_axis.size() >= x_rank) {
if (vec_axis.size() > x_rank) {
vec_dim = {-1};
} else {
vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
@@ -3125,22 +3115,6 @@ void ReduceInferMetaBase(const MetaTensor& x,
out->set_layout(x.layout());
}

void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out,
MetaConfig config) {
if (config.is_runtime || !axis.FromTensor()) {
ReduceInferMetaBase(x, axis.GetData(), keep_dim, reduce_all, out);
} else {
DDim out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
out->set_dims(out_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
}

void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
@@ -3153,6 +3127,23 @@ void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
ReduceIntArrayAxisInferMetaBase(x, axis, keep_dim, reduce_all, out, config);
}

void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out,
MetaConfig config) {
DDim out_dim;
if (config.is_runtime || !axis.FromTensor()) {
out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
} else {
out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
}
out->set_dims(out_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}

void ReduceScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
auto dim = x.dims();
if (dim[0] > 0 || dim[0] < -1) {
@@ -3951,6 +3942,105 @@ void StridedSliceInferMeta(const MetaTensor& x,
x, axes, starts, ends, strides, infer_flags, decrease_axis, out, config);
}

// TODO(zhouwei): OriginReduceInferDim doesn't support 0D, remove in future
DDim OriginReduceInferDim(const MetaTensor& x,
const std::vector<int64_t>& axis,
bool keep_dim,
bool reduce_all) {
auto x_rank = x.dims().size();

std::vector<int64_t> formated_axis = axis;
for (size_t i = 0; i < axis.size(); ++i) {
if (x_rank == 0) {
PADDLE_ENFORCE_EQ(
axis[i] == 0 || axis[i] == -1,
true,
phi::errors::InvalidArgument(
"When input 0D Tensor, the axis can only be -1, 0, None or []"));
} else {
PADDLE_ENFORCE_LT(axis[i],
x_rank,
errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received dim index = %d.",
i,
x_rank,
axis[i]));
PADDLE_ENFORCE_GE(axis[i],
-x_rank,
errors::InvalidArgument(
"The reduce dim index %d should be in the "
"range [ -dimension(X), dimension(X) ) "
"which dimesion = %d. But received dim index = %d.",
i,
x_rank,
axis[i]));
}

if (axis[i] < 0) {
formated_axis[i] = axis[i] + x_rank;
}
}

bool full_dim = true;
std::set<int64_t> dims_set(formated_axis.begin(), formated_axis.end());
for (int64_t i = 0; i < x_rank; ++i) {
if (dims_set.find(i) == dims_set.end()) {
full_dim = false;
break;
}
}
reduce_all = reduce_all || full_dim;

std::vector<int64_t> out_dim_vector;
for (int64_t i = 0; i < x_rank; ++i) {
if (reduce_all || dims_set.find(i) != dims_set.end()) {
if (keep_dim) {
out_dim_vector.push_back(1);
} else {
continue;
}
} else {
out_dim_vector.push_back(x.dims().at(i));
}
}
if (x_rank > 0 && out_dim_vector.size() == 0) {
out_dim_vector.push_back(1);
}

DDim out_dim = phi::make_ddim(out_dim_vector);
return out_dim;
}

// TODO(zhouwei): OriginReduceInferDim doesn't support 0D, remove in future
DDim OriginReduceInferDimForIntArrayAxis(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all) {
std::vector<int64_t> vec_axis = axis.GetData();
std::vector<int64_t> vec_dim;
if (reduce_all) {
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), 1);
} else {
vec_dim = {1};
}
} else {
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), -1);
} else {
auto x_rank = static_cast<size_t>(x.dims().size());
if (vec_axis.size() >= x_rank) {
vec_dim = {-1};
} else {
vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
}
}
}
return phi::make_ddim(vec_dim);
}
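
For contrast, a hedged Python mirror of the Origin helpers added above (again a standalone sketch, not Paddle API): they keep the pre-0-D convention by padding an all-axes, keep_dim=false reduction back to shape [1], which appears to be why ops outside this PR's list (max, mean, sum) are routed through them.

```python
def origin_reduce_infer_shape(x_shape, axes, keep_dim, reduce_all):
    # Hypothetical mirror of OriginReduceInferDim for the cases shown in this
    # diff: identical to the new rule except for the final [1] padding.
    rank = len(x_shape)
    dims_set = {a + rank if a < 0 else a for a in axes}
    reduce_all = reduce_all or set(range(rank)).issubset(dims_set)
    out = []
    for i in range(rank):
        if reduce_all or i in dims_set:
            if keep_dim:
                out.append(1)
        else:
            out.append(x_shape[i])
    if rank > 0 and not out:
        out.append(1)  # legacy behavior: never report a 0-D output shape
    return out

print(origin_reduce_infer_shape([2, 3], [0, 1], keep_dim=False, reduce_all=False))  # [1]
# The new ReduceInferDim sketched earlier returns [] (0-D) for the same inputs.
```
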

/* Why not use SumRawInferMeta directly?
Because we need make InferMetaFunction's args follow the design of
ops.yaml
@@ -3977,9 +4067,10 @@ void SumRawInferMeta(const MetaTensor& x,
MetaConfig config) {
DDim out_dim;
if (config.is_runtime || !axis.FromTensor()) {
out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
out_dim = OriginReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
} else {
out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
out_dim =
OriginReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
}

DataType out_dtype;
@@ -3998,6 +4089,38 @@ void SumRawInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}

// TODO(zhouwei): OriginReduce doesn't support 0D, remove in future
void OriginReduceInferMeta(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
MetaTensor* out,
MetaConfig config) {
bool reduce_all = false;
if (axis.size() == 0) {
reduce_all = true;
}
OriginReduceInferMetaBase(x, axis, keep_dim, reduce_all, out, config);
}

// TODO(zhouwei): OriginReduce doesn't support 0D, remove in future
void OriginReduceInferMetaBase(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out,
MetaConfig config) {
DDim out_dim;
if (config.is_runtime || !axis.FromTensor()) {
out_dim = OriginReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
} else {
out_dim =
OriginReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
}
out->set_dims(out_dim);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -572,6 +572,19 @@ void SumRawInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void OriginReduceInferMeta(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
MetaTensor* out,
MetaConfig config = MetaConfig());

void OriginReduceInferMetaBase(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
bool reduce_all,
MetaTensor* out,
MetaConfig config = MetaConfig());

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
10 changes: 6 additions & 4 deletions paddle/phi/kernels/cpu/mean_all_grad_kernel.cc
@@ -32,10 +32,12 @@ void MeanAllGradKernel(const Context& dev_ctx,
out_grad.numel()));
dev_ctx.template Alloc<T>(x_grad);

T ig_size = static_cast<T>(x_grad->numel());
Eigen::DSizes<int, 1> bcast(static_cast<int>(ig_size));
EigenVector<T>::Flatten(*x_grad).device(*dev_ctx.eigen_device()) =
(EigenVector<T>::From(out_grad) / ig_size).broadcast(bcast);
T x_numel = static_cast<T>(x_grad->numel());
Eigen::DSizes<int, 1> bcast(static_cast<int>(x_numel));
auto eigen_x = EigenVector<T>::Flatten(*x_grad);
auto eigen_dout = EigenVector<T>::Flatten(out_grad);
eigen_x.device(*dev_ctx.eigen_device()) =
(eigen_dout / x_numel).broadcast(bcast);
}

} // namespace phi
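
A NumPy sketch of the arithmetic the rewritten MeanAllGradKernel performs (shapes and the helper name are illustrative): the upstream gradient, which may now be 0-D, is divided by the input's element count and broadcast to every element of x_grad.

```python
import numpy as np

def mean_all_grad(out_grad, x_shape):
    # x_grad[i] = out_grad / numel(x), whether out_grad is 0-D or shape [1].
    n = int(np.prod(x_shape))
    flat = np.asarray(out_grad).reshape(-1) / n
    return np.broadcast_to(flat, (n,)).reshape(x_shape)

print(mean_all_grad(np.float32(3.0), (2, 3)))  # every entry == 0.5
print(mean_all_grad(np.array([3.0]), (2, 3)))  # same result for a shape-[1] grad
```
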
5 changes: 3 additions & 2 deletions python/paddle/distributed/auto_parallel/cost/base_cost.py
@@ -13,7 +13,8 @@
# limitations under the License

from collections import OrderedDict
from functools import reduce

import numpy as np

import paddle
from paddle.utils.flops import flops
@@ -807,7 +808,7 @@ def comm_count(self):
factor = 8
else:
raise ValueError(f"Unsupported comm dtype {dtype}")
comm_count = reduce(lambda x, y: x * y, shape) * factor
comm_count = int(np.prod(shape)) * factor
self._comm_count = comm_count

return self._comm_count
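
A quick note on why functools.reduce was swapped for np.prod here (my reading, given that a 0-D tensor reports an empty shape): reduce without an initial value raises on an empty sequence, whereas np.prod of an empty shape is 1, which is the correct element count for a 0-D tensor.

```python
from functools import reduce

import numpy as np

print(int(np.prod([])))            # 1: an empty shape still means one element
try:
    reduce(lambda x, y: x * y, [])
except TypeError as err:
    print(err)                     # reduce() of empty iterable with no initial value
```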