[Zero-Dim] Support all/any/min/prod/logsumexp/amax/amin/some loss output 0D, test=allcase
zhwesky2010 committed Apr 19, 2023
1 parent 48ccb78 commit 705d0a8
Showing 19 changed files with 371 additions and 105 deletions.
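Note: the user-visible effect of this commit is that fully-reduced outputs become 0-D tensors. A minimal sketch of the intended behavior (assuming a Paddle build that includes this change; not taken from the diff itself):

```python
# Sketch (assumed behavior after this commit): reducing over all axes
# now yields a 0-D tensor with shape [], where it used to be [1].
import paddle

x = paddle.rand([2, 3])
assert paddle.min(x).shape == []
assert paddle.prod(x).shape == []
assert paddle.logsumexp(x).shape == []
assert paddle.amax(x).shape == []
assert paddle.all(x > 0.5).shape == []
```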
4 changes: 2 additions & 2 deletions paddle/fluid/operators/assert_op.cc
@@ -58,8 +58,8 @@ class AssertOp : public framework::OperatorBase {
"Input(Condition) of AssertOp is not found."));
const phi::DenseTensor &cond = cond_var_ptr->Get<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(
cond.dims(),
phi::make_ddim({1}),
cond.numel(),
1,
platform::errors::InvalidArgument(
"The numel of Input(Condition) of AssertOp must be 1. But now "
"the Condition's shape is %s.",
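Note: the switch from a dims comparison to a numel comparison is what lets a 0-D condition through: shape [] and shape [1] both contain exactly one element. A plain-Python sketch of the two checks (NumPy standing in for phi::DenseTensor; not the Paddle API):

```python
import numpy as np

cond_0d = np.asarray(True)    # shape (), one element
cond_1d = np.asarray([True])  # shape (1,), one element

assert cond_0d.size == 1 and cond_1d.size == 1  # new numel check: both pass
assert cond_0d.shape != (1,)                    # old dims-vs-{1} check rejected 0-D
```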
7 changes: 3 additions & 4 deletions paddle/fluid/operators/reduce_ops/reduce_max_op.cc
@@ -54,10 +54,9 @@ class ReduceMaxCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
} // namespace operators
} // namespace paddle

-DECLARE_INFER_SHAPE_FUNCTOR(
-    reduce_max,
-    ReduceMaxInferShapeFunctor,
-    PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
+DECLARE_INFER_SHAPE_FUNCTOR(reduce_max,
+                            ReduceMaxInferShapeFunctor,
+                            PD_INFER_META(phi::OriginReduceInferMetaBase));

REGISTER_OPERATOR(
reduce_max,
7 changes: 3 additions & 4 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.cc
@@ -98,10 +98,9 @@ class __reduce_meanMaker__ : public ops::ReduceBaseOpMaker {
virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
};

-DECLARE_INFER_SHAPE_FUNCTOR(
-    reduce_mean,
-    ReduceMeanInferShapeFunctor,
-    PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
+DECLARE_INFER_SHAPE_FUNCTOR(reduce_mean,
+                            ReduceMeanInferShapeFunctor,
+                            PD_INFER_META(phi::OriginReduceInferMetaBase));

REGISTER_OPERATOR(reduce_mean,
ops::ReduceBaseOp,
@@ -1132,7 +1132,7 @@ void prod_grad(const Tensor& x,
if (!keep_dim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
-        for (int64_t i = 1; i < x_dim_size; i++) {
+        for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
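Note: the one-character change above matters in the reduce_all path with keep_dim=false: the axis list used to re-expand the reduced gradient must cover every axis of x, but the old loop started at 1 and silently skipped axis 0. A plain-Python illustration:

```python
# Illustration of the off-by-one fixed above (plain Python, not Paddle API).
x_dim_size = 3
old_axis_ = list(range(1, x_dim_size))  # [1, 2]    -- axis 0 was missing
new_axis_ = list(range(0, x_dim_size))  # [0, 1, 2] -- covers all of x
assert new_axis_ == [0, 1, 2]
```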
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -784,7 +784,7 @@
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
infer_meta :
-    func : ReduceIntArrayAxisInferMeta
+    func : OriginReduceInferMeta
kernel :
func : max
backward : max_grad
@@ -820,7 +820,7 @@
args : (Tensor x, IntArray axis={}, bool keepdim=false)
output : Tensor(out)
infer_meta :
-    func : ReduceIntArrayAxisInferMeta
+    func : OriginReduceInferMeta
kernel :
func : mean
backward : mean_grad
197 changes: 160 additions & 37 deletions paddle/phi/infermeta/unary.cc
@@ -2146,7 +2146,7 @@ void MaxPoolWithIndexInferMeta(const MetaTensor& x,
}

void MeanAllInferMeta(const MetaTensor& x, MetaTensor* out) {
-  out->set_dims(phi::make_ddim({1}));
+  out->set_dims(phi::make_ddim({}));
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
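Note: with make_ddim({}) the mean-over-all-elements op now infers a rank-0 output. Assuming the usual Python-level mapping, the visible effect is:

```python
# Assumed user-visible effect: a full mean now infers shape [] (0-D)
# instead of the old padded shape [1].
import paddle

out = paddle.mean(paddle.rand([2, 3]))  # reduce over all elements
assert out.shape == []
```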
@@ -3050,29 +3050,19 @@ DDim ReduceInferDim(const MetaTensor& x,
reduce_all = reduce_all || full_dim;

std::vector<int64_t> out_dim_vector;
-  if (keep_dim) {
-    for (int64_t i = 0; i < x_rank; ++i) {
-      if (reduce_all || dims_set.find(i) != dims_set.end()) {
-        out_dim_vector.push_back(1);
-      } else {
-        out_dim_vector.push_back(x.dims().at(i));
-      }
-    }
-  } else {
-    for (int64_t i = 0; i < x_rank; ++i) {
-      if (reduce_all || dims_set.find(i) != dims_set.end()) {
-        continue;
-      } else {
-        out_dim_vector.push_back(x.dims().at(i));
-      }
-    }
-
-    if (x_rank > 0 && out_dim_vector.size() == 0) {
-      out_dim_vector.push_back(1);
-    }
-  }
-  DDim out_dim = phi::make_ddim(out_dim_vector);
+  for (int64_t i = 0; i < x_rank; ++i) {
+    if (reduce_all || dims_set.find(i) != dims_set.end()) {
+      if (keep_dim) {
+        out_dim_vector.push_back(1);
+      } else {
+        continue;
+      }
+    } else {
+      out_dim_vector.push_back(x.dims().at(i));
+    }
+  }

+  DDim out_dim = phi::make_ddim(out_dim_vector);
return out_dim;
}
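Note: the rewritten loop folds the keep_dim branch into a single pass and drops the old `x_rank > 0 && out_dim_vector.size() == 0` fallback that padded a fully-reduced shape back to [1]. A Python mirror of the new logic (assumed semantics, not the Paddle API):

```python
# Python mirror of the new ReduceInferDim loop above (assumed semantics):
# reduced axes become 1 under keep_dim or are dropped outright, so a full
# reduction with keep_dim=False now infers the 0-D shape [].
def reduce_infer_dim(shape, axis, keep_dim, reduce_all):
    rank = len(shape)
    dims = {a + rank if a < 0 else a for a in axis}
    reduce_all = reduce_all or dims == set(range(rank))
    out = []
    for i in range(rank):
        if reduce_all or i in dims:
            if keep_dim:
                out.append(1)
            # else: the axis is dropped; no more [1] padding
        else:
            out.append(shape[i])
    return out

assert reduce_infer_dim([2, 3], [], False, True) == []       # now 0-D
assert reduce_infer_dim([2, 3], [0], True, False) == [1, 3]
assert reduce_infer_dim([2, 3], [-1], False, False) == [2]
```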

@@ -3086,14 +3076,14 @@ DDim ReduceInferDimForIntArrayAxis(const MetaTensor& x,
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), 1);
} else {
-      vec_dim = {1};
+      vec_dim = {};
}
} else {
if (keep_dim) {
vec_dim = std::vector<int64_t>(x.dims().size(), -1);
} else {
auto x_rank = static_cast<size_t>(x.dims().size());
-      if (vec_axis.size() >= x_rank) {
+      if (vec_axis.size() > x_rank) {
vec_dim = {-1};
} else {
vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
@@ -3125,22 +3115,6 @@ void ReduceInferMetaBase(const MetaTensor& x,
out->set_layout(x.layout());
}

-void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
-                                     const IntArray& axis,
-                                     bool keep_dim,
-                                     bool reduce_all,
-                                     MetaTensor* out,
-                                     MetaConfig config) {
-  if (config.is_runtime || !axis.FromTensor()) {
-    ReduceInferMetaBase(x, axis.GetData(), keep_dim, reduce_all, out);
-  } else {
-    DDim out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
-    out->set_dims(out_dim);
-    out->set_dtype(x.dtype());
-    out->set_layout(x.layout());
-  }
-}

void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
const IntArray& axis,
bool keep_dim,
@@ -3153,6 +3127,23 @@ void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
ReduceIntArrayAxisInferMetaBase(x, axis, keep_dim, reduce_all, out, config);
}

+void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
+                                     const IntArray& axis,
+                                     bool keep_dim,
+                                     bool reduce_all,
+                                     MetaTensor* out,
+                                     MetaConfig config) {
+  DDim out_dim;
+  if (config.is_runtime || !axis.FromTensor()) {
+    out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
+  } else {
+    out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
+  }
+  out->set_dims(out_dim);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+}

void ReduceScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
auto dim = x.dims();
if (dim[0] > 0 || dim[0] < -1) {
@@ -3951,6 +3942,105 @@ void StridedSliceInferMeta(const MetaTensor& x,
x, axes, starts, ends, strides, infer_flags, decrease_axis, out, config);
}

+// TODO(zhouwei): OriginReduceInferDim doesn't support 0D, remove in future
+DDim OriginReduceInferDim(const MetaTensor& x,
+                          const std::vector<int64_t>& axis,
+                          bool keep_dim,
+                          bool reduce_all) {
+  auto x_rank = x.dims().size();
+
+  std::vector<int64_t> formated_axis = axis;
+  for (size_t i = 0; i < axis.size(); ++i) {
+    if (x_rank == 0) {
+      PADDLE_ENFORCE_EQ(
+          axis[i] == 0 || axis[i] == -1,
+          true,
+          phi::errors::InvalidArgument(
+              "When input 0D Tensor, the axis can only be -1, 0, None or []"));
+    } else {
+      PADDLE_ENFORCE_LT(axis[i],
+                        x_rank,
+                        errors::InvalidArgument(
+                            "The reduce dim index %d should be in the "
+                            "range [ -dimension(X), dimension(X) ) "
+                            "which dimesion = %d. But received dim index = %d.",
+                            i,
+                            x_rank,
+                            axis[i]));
+      PADDLE_ENFORCE_GE(axis[i],
+                        -x_rank,
+                        errors::InvalidArgument(
+                            "The reduce dim index %d should be in the "
+                            "range [ -dimension(X), dimension(X) ) "
+                            "which dimesion = %d. But received dim index = %d.",
+                            i,
+                            x_rank,
+                            axis[i]));
+    }
+
+    if (axis[i] < 0) {
+      formated_axis[i] = axis[i] + x_rank;
+    }
+  }
+
+  bool full_dim = true;
+  std::set<int64_t> dims_set(formated_axis.begin(), formated_axis.end());
+  for (int64_t i = 0; i < x_rank; ++i) {
+    if (dims_set.find(i) == dims_set.end()) {
+      full_dim = false;
+      break;
+    }
+  }
+  reduce_all = reduce_all || full_dim;
+
+  std::vector<int64_t> out_dim_vector;
+  for (int64_t i = 0; i < x_rank; ++i) {
+    if (reduce_all || dims_set.find(i) != dims_set.end()) {
+      if (keep_dim) {
+        out_dim_vector.push_back(1);
+      } else {
+        continue;
+      }
+    } else {
+      out_dim_vector.push_back(x.dims().at(i));
+    }
+  }
+  if (x_rank > 0 && out_dim_vector.size() == 0) {
+    out_dim_vector.push_back(1);
+  }
+
+  DDim out_dim = phi::make_ddim(out_dim_vector);
+  return out_dim;
+}
+
+// TODO(zhouwei): OriginReduceInferDim doesn't support 0D, remove in future
+DDim OriginReduceInferDimForIntArrayAxis(const MetaTensor& x,
+                                         const IntArray& axis,
+                                         bool keep_dim,
+                                         bool reduce_all) {
+  std::vector<int64_t> vec_axis = axis.GetData();
+  std::vector<int64_t> vec_dim;
+  if (reduce_all) {
+    if (keep_dim) {
+      vec_dim = std::vector<int64_t>(x.dims().size(), 1);
+    } else {
+      vec_dim = {1};
+    }
+  } else {
+    if (keep_dim) {
+      vec_dim = std::vector<int64_t>(x.dims().size(), -1);
+    } else {
+      auto x_rank = static_cast<size_t>(x.dims().size());
+      if (vec_axis.size() >= x_rank) {
+        vec_dim = {-1};
+      } else {
+        vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
+      }
+    }
+  }
+  return phi::make_ddim(vec_dim);
+}
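Note: these Origin* copies deliberately preserve the pre-0D conventions for callers that still depend on them (per the TODOs, to be removed later): a full reduction still infers [1], and when the axis values live in a Tensor at compile time the sizes are unknown, so dims are filled with -1 placeholders. A rough Python mirror (assumed semantics):

```python
# Rough mirror of OriginReduceInferDimForIntArrayAxis above (assumed
# semantics): legacy [1] for full reductions, -1 placeholders when axis
# values are unknown until runtime.
def origin_reduce_infer_for_int_array_axis(rank, n_axis, keep_dim, reduce_all):
    if reduce_all:
        return [1] * rank if keep_dim else [1]  # legacy: [1], never []
    if keep_dim:
        return [-1] * rank                      # sizes unknown at compile time
    return [-1] if n_axis >= rank else [-1] * (rank - n_axis)

assert origin_reduce_infer_for_int_array_axis(3, 0, False, True) == [1]
assert origin_reduce_infer_for_int_array_axis(3, 2, True, False) == [-1, -1, -1]
assert origin_reduce_infer_for_int_array_axis(3, 2, False, False) == [-1]
```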

/* Why not use SumRawInferMeta directly?
Because we need make InferMetaFunction's args follow the design of
ops.yaml
@@ -3977,9 +4067,10 @@ void SumRawInferMeta(const MetaTensor& x,
MetaConfig config) {
DDim out_dim;
if (config.is_runtime || !axis.FromTensor()) {
-    out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
+    out_dim = OriginReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
  } else {
-    out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
+    out_dim =
+        OriginReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
}

DataType out_dtype;
@@ -3998,6 +4089,38 @@
out->set_layout(x.layout());
}

+// TODO(zhouwei): OriginReduce doesn't support 0D, remove in future
+void OriginReduceInferMeta(const MetaTensor& x,
+                           const IntArray& axis,
+                           bool keep_dim,
+                           MetaTensor* out,
+                           MetaConfig config) {
+  bool reduce_all = false;
+  if (axis.size() == 0) {
+    reduce_all = true;
+  }
+  OriginReduceInferMetaBase(x, axis, keep_dim, reduce_all, out, config);
+}
+
+// TODO(zhouwei): OriginReduce doesn't support 0D, remove in future
+void OriginReduceInferMetaBase(const MetaTensor& x,
+                               const IntArray& axis,
+                               bool keep_dim,
+                               bool reduce_all,
+                               MetaTensor* out,
+                               MetaConfig config) {
+  DDim out_dim;
+  if (config.is_runtime || !axis.FromTensor()) {
+    out_dim = OriginReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
+  } else {
+    out_dim =
+        OriginReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
+  }
+  out->set_dims(out_dim);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+}

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -572,6 +572,19 @@ void SumRawInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

+void OriginReduceInferMeta(const MetaTensor& x,
+                           const IntArray& axis,
+                           bool keep_dim,
+                           MetaTensor* out,
+                           MetaConfig config = MetaConfig());
+
+void OriginReduceInferMetaBase(const MetaTensor& x,
+                               const IntArray& axis,
+                               bool keep_dim,
+                               bool reduce_all,
+                               MetaTensor* out,
+                               MetaConfig config = MetaConfig());

void SvdInferMeta(const MetaTensor& x,
bool full_matrices,
MetaTensor* u,
10 changes: 6 additions & 4 deletions paddle/phi/kernels/cpu/mean_all_grad_kernel.cc
@@ -32,10 +32,12 @@ void MeanAllGradKernel(const Context& dev_ctx,
out_grad.numel()));
dev_ctx.template Alloc<T>(x_grad);

-  T ig_size = static_cast<T>(x_grad->numel());
-  Eigen::DSizes<int, 1> bcast(static_cast<int>(ig_size));
-  EigenVector<T>::Flatten(*x_grad).device(*dev_ctx.eigen_device()) =
-      (EigenVector<T>::From(out_grad) / ig_size).broadcast(bcast);
+  T x_numel = static_cast<T>(x_grad->numel());
+  Eigen::DSizes<int, 1> bcast(static_cast<int>(x_numel));
+  auto eigen_x = EigenVector<T>::Flatten(*x_grad);
+  auto eigen_dout = EigenVector<T>::Flatten(out_grad);
+  eigen_x.device(*dev_ctx.eigen_device()) =
+      (eigen_dout / x_numel).broadcast(bcast);
}

} // namespace phi
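Note: numerically the kernel is unchanged; it still broadcasts out_grad / numel(x) across x. Flattening out_grad (instead of EigenVector<T>::From) is what tolerates a 0-D out_grad. A NumPy sketch of the math:

```python
# NumPy sketch of the gradient this kernel computes: d(mean(x))/dx is
# out_grad / numel(x), broadcast to x's shape; out_grad may now be 0-D.
import numpy as np

x = np.random.rand(2, 3)
dout = np.asarray(1.0)                        # 0-D upstream gradient
dx = np.broadcast_to(dout / x.size, x.shape)
assert dx.shape == (2, 3) and np.allclose(dx, 1.0 / 6.0)
```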
5 changes: 3 additions & 2 deletions python/paddle/distributed/auto_parallel/cost/base_cost.py
@@ -13,7 +13,8 @@
# limitations under the License

from collections import OrderedDict
-from functools import reduce
+
+import numpy as np

import paddle
from paddle.utils.flops import flops
@@ -807,7 +808,7 @@ def comm_count(self):
factor = 8
else:
raise ValueError(f"Unsupported comm dtype {dtype}")
-        comm_count = reduce(lambda x, y: x * y, shape) * factor
+        comm_count = int(np.prod(shape)) * factor
self._comm_count = comm_count

return self._comm_count
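Note: the swap to np.prod is not just stylistic. A 0-D tensor has shape [], and reduce over an empty sequence with no initializer raises, while np.prod([]) is the empty product, 1:

```python
# Why np.prod replaces functools.reduce here: an empty shape (0-D tensor)
# breaks reduce but yields the correct empty product with np.prod.
import numpy as np
from functools import reduce

assert int(np.prod([])) == 1          # empty shape: exactly one element
try:
    reduce(lambda a, b: a * b, [])    # the old code path on shape []
except TypeError as err:
    print(err)                        # "reduce() of empty iterable ..."
```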