[OpAttr]Adapt tensor axis for reduce_min/max/mean/sum/prod (#45078)
* [OpAttr]Adapt tensor axis for reduce_min/max/mean/sum/prod
0x45f committed Aug 30, 2022
1 parent e221a60 commit 32f42e9
Showing 59 changed files with 513 additions and 189 deletions.
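
This change lets the reduction axis of reduce_min/max/mean/sum/prod be supplied as a runtime Tensor instead of a fixed attribute, by routing the axis through phi::IntArray. A minimal sketch of the distinction the new code keys on (illustrative, using only APIs that appear in this diff):

// An IntArray can wrap a host-side vector (values known when shapes are
// inferred) or a DenseTensor (values known only at kernel launch).
std::vector<int64_t> host_axis = {0, 1};
phi::IntArray static_axis(host_axis);       // static_axis.FromTensor() == false
// phi::IntArray dynamic_axis(axis_tensor); // tensor-backed: FromTensor() == true,
//                                          // so exact shapes must be deferred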
7 changes: 4 additions & 3 deletions paddle/fluid/operators/reduce_ops/reduce_max_op.cc
@@ -25,9 +25,10 @@ class ReduceMaxOpMaker : public ops::ReduceOpMaker {
   virtual std::string GetOpType() const { return "Reduce reduce_max"; }
 };
 
-DECLARE_INFER_SHAPE_FUNCTOR(reduce_max,
-                            ReduceMaxInferShapeFunctor,
-                            PD_INFER_META(phi::ReduceInferMetaBase));
+DECLARE_INFER_SHAPE_FUNCTOR(
+    reduce_max,
+    ReduceMaxInferShapeFunctor,
+    PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
 
 REGISTER_OPERATOR(
     reduce_max,
7 changes: 4 additions & 3 deletions paddle/fluid/operators/reduce_ops/reduce_mean_op.cc
@@ -97,9 +97,10 @@ class __reduce_meanMaker__ : public ops::ReduceOpMaker {
   virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
 };
 
-DECLARE_INFER_SHAPE_FUNCTOR(reduce_mean,
-                            ReduceMeanInferShapeFunctor,
-                            PD_INFER_META(phi::ReduceInferMetaBase));
+DECLARE_INFER_SHAPE_FUNCTOR(
+    reduce_mean,
+    ReduceMeanInferShapeFunctor,
+    PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
 
 REGISTER_OPERATOR(reduce_mean,
                   ops::ReduceOp,
7 changes: 4 additions & 3 deletions paddle/fluid/operators/reduce_ops/reduce_min_op.cc
@@ -25,9 +25,10 @@ class ReduceMinOpMaker : public ops::ReduceOpMaker {
   virtual std::string GetOpType() const { return "Reduce reduce_min"; }
 };
 
-DECLARE_INFER_SHAPE_FUNCTOR(reduce_min,
-                            ReduceMinInferShapeFunctor,
-                            PD_INFER_META(phi::ReduceInferMetaBase));
+DECLARE_INFER_SHAPE_FUNCTOR(
+    reduce_min,
+    ReduceMinInferShapeFunctor,
+    PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
 
 REGISTER_OPERATOR(
     reduce_min,
3 changes: 2 additions & 1 deletion paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -698,7 +698,8 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
              "Must be in the range [-rank(input), rank(input)). "
              "If `dim[i] < 0`, the dims[i] to reduce is `rank + dims[i]`. "
              "Note that reducing on the first dim will make the LoD info lost.")
-        .SetDefault({0});
+        .SetDefault({0})
+        .SupportTensor();
     AddAttr<bool>("keep_dim",
                   "(bool, default false) "
                   "If true, retain the reduced dimension with length 1.")
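
The one-line addition above is the enabling hook on the operator side: SupportTensor() marks the dim attribute as one that may be fed from a runtime Tensor rather than a compile-time constant. A condensed sketch of the full declaration, with the attribute type (std::vector<int>) assumed from the surrounding maker rather than shown in this hunk:

// Sketch of the attribute declaration after this change (type assumed).
AddAttr<std::vector<int>>("dim",
                          "(list<int>, default {0}) dimensions to reduce")
    .SetDefault({0})
    .SupportTensor();  // allow `dim` to be overridden by a runtime Tensor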
7 changes: 4 additions & 3 deletions paddle/fluid/operators/reduce_ops/reduce_prod_op.cc
@@ -35,9 +35,10 @@ class ReduceProdOpMaker : public ops::ReduceOpMaker {
   virtual std::string GetOpType() const { return "Reduce reduce_prod"; }
 };
 
-DECLARE_INFER_SHAPE_FUNCTOR(reduce_prod,
-                            ReduceProdInferShapeFunctor,
-                            PD_INFER_META(phi::ReduceInferMetaBase));
+DECLARE_INFER_SHAPE_FUNCTOR(
+    reduce_prod,
+    ReduceProdInferShapeFunctor,
+    PD_INFER_META(phi::ReduceIntArrayAxisInferMetaBase));
 
 REGISTER_OPERATOR(
     reduce_prod,
18 changes: 9 additions & 9 deletions paddle/phi/api/yaml/legacy_api.yaml
@@ -1668,10 +1668,10 @@
   func : matrix_rank_tol
 
 - api : max
-  args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
+  args : (Tensor x, IntArray dims={}, bool keep_dim=false)
   output : Tensor(out)
   infer_meta :
-    func : ReduceInferMeta
+    func : ReduceIntArrayAxisInferMeta
   kernel :
     func : max
   backward : max_grad
@@ -1713,10 +1713,10 @@
   backward : maxout_grad
 
 - api : mean
-  args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
+  args : (Tensor x, IntArray dims={}, bool keep_dim=false)
   output : Tensor(out)
   infer_meta :
-    func : ReduceInferMeta
+    func : ReduceIntArrayAxisInferMeta
   kernel :
     func : mean
   backward : mean_grad
@@ -1762,10 +1762,10 @@
   backward : meshgrid_grad
 
 - api : min
-  args : (Tensor x, int64_t[] dims={}, bool keep_dim=false)
+  args : (Tensor x, IntArray dims={}, bool keep_dim=false)
   output : Tensor(out)
   infer_meta :
-    func : ReduceInferMeta
+    func : ReduceIntArrayAxisInferMeta
   kernel :
     func : min
   backward : min_grad
@@ -2091,10 +2091,10 @@
   backward : reciprocal_grad
 
 - api : reduce_prod
-  args : (Tensor x, int64_t[] dims, bool keep_dim, bool reduce_all)
+  args : (Tensor x, IntArray dims, bool keep_dim, bool reduce_all)
   output : Tensor
   infer_meta :
-    func : ReduceInferMetaBase
+    func : ReduceIntArrayAxisInferMetaBase
   kernel :
     func : prod_raw
   backward : reduce_prod_grad
@@ -2555,7 +2555,7 @@
   backward : subtract_grad
 
 - api : sum
-  args : (Tensor x, int64_t[] dims={}, DataType out_dtype=DataType::UNDEFINED, bool keep_dim=false)
+  args : (Tensor x, IntArray dims={}, DataType out_dtype=DataType::UNDEFINED, bool keep_dim=false)
   output : Tensor(out)
   infer_meta :
     func : SumInferMeta
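
These args entries drive API code generation, so switching int64_t[] dims to IntArray dims changes the generated C++ signatures as well. A hedged sketch of a resulting call site; the paddle::experimental namespace and the exact generated signature are assumptions, not shown in this diff:

// Assumed generated signature:
//   Tensor max(const Tensor& x, const IntArray& dims, bool keep_dim);
paddle::experimental::Tensor out = paddle::experimental::max(
    x, phi::IntArray(std::vector<int64_t>{0}), /*keep_dim=*/false);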
28 changes: 14 additions & 14 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -1451,8 +1451,8 @@
   func : matrix_power_grad
 
 - backward_api : max_grad
-  forward: max (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out)
-  args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false)
+  forward: max (Tensor x, IntArray dims={}, bool keep_dim=false) -> Tensor(out)
+  args : (Tensor x, Tensor out, Tensor out_grad, IntArray dims={}, bool keep_dim=false, bool reduce_all=false)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
@@ -1509,14 +1509,14 @@
   func : mean_all_grad
 
 - backward_api : mean_double_grad
-  forward: mean_grad (Tensor x, Tensor grad_out, int64_t[] dims={}, bool keep_dim=false, bool reduce_all = false) -> Tensor(grad_x)
-  args : (Tensor grad_x_grad, int64_t[] dims={}, bool keep_dim=false)
+  forward: mean_grad (Tensor x, Tensor grad_out, IntArray dims={}, bool keep_dim=false, bool reduce_all = false) -> Tensor(grad_x)
+  args : (Tensor grad_x_grad, IntArray dims={}, bool keep_dim=false)
   output : Tensor(grad_out_grad)
   invoke : mean(grad_x_grad, dims, keep_dim)
 
 - backward_api : mean_grad
-  forward: mean (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false)
+  forward: mean (Tensor x, IntArray dims={}, bool keep_dim=false) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, IntArray dims={}, bool keep_dim=false, bool reduce_all=false)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
@@ -1536,8 +1536,8 @@
   func : meshgrid_grad
 
 - backward_api : min_grad
-  forward: min (Tensor x, int64_t[] dims={}, bool keep_dim=false) -> Tensor(out)
-  args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims={}, bool keep_dim=false, bool reduce_all=false)
+  forward: min (Tensor x, IntArray dims={}, bool keep_dim=false) -> Tensor(out)
+  args : (Tensor x, Tensor out, Tensor out_grad, IntArray dims={}, bool keep_dim=false, bool reduce_all=false)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
@@ -1849,8 +1849,8 @@
   inplace : (out_grad -> x_grad)
 
 - backward_api : reduce_prod_grad
-  forward : reduce_prod (Tensor x, int64_t[] dims, bool keep_dim, bool reduce_all) -> Tensor(out)
-  args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] dims, bool keep_dim, bool reduce_all)
+  forward : reduce_prod (Tensor x, IntArray dims, bool keep_dim, bool reduce_all) -> Tensor(out)
+  args : (Tensor x, Tensor out, Tensor out_grad, IntArray dims, bool keep_dim, bool reduce_all)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
@@ -2386,14 +2386,14 @@
   inplace : (out_grad -> x_grad)
 
 - backward_api : sum_double_grad
-  forward : sum_grad (Tensor x, Tensor grad_out, int64_t[] dims, bool keep_dim, bool reduce_all=false) -> Tensor(grad_x)
-  args : (Tensor grad_x_grad, int64_t[] dims={}, bool keep_dim=false)
+  forward : sum_grad (Tensor x, Tensor grad_out, IntArray dims, bool keep_dim, bool reduce_all=false) -> Tensor(grad_x)
+  args : (Tensor grad_x_grad, IntArray dims={}, bool keep_dim=false)
   output : Tensor(grad_out_grad)
   invoke : sum(grad_x_grad, dims, grad_x_grad.dtype(), keep_dim)
 
 - backward_api : sum_grad
-  forward : sum (Tensor x, int64_t[] dims={}, DataType out_dtype=DataType::UNDEFINED, bool keep_dim=false) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad, int64_t[] dims, bool keep_dim, bool reduce_all=false)
+  forward : sum (Tensor x, IntArray dims={}, DataType out_dtype=DataType::UNDEFINED, bool keep_dim=false) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, IntArray dims, bool keep_dim, bool reduce_all=false)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
74 changes: 68 additions & 6 deletions paddle/phi/infermeta/unary.cc
@@ -2656,6 +2656,33 @@ DDim ReduceInferDim(const MetaTensor& x,
   return out_dim;
 }
 
+DDim ReduceInferDimForIntArrayAxis(const MetaTensor& x,
+                                   const IntArray& axis,
+                                   bool keep_dim,
+                                   bool reduce_all) {
+  std::vector<int64_t> vec_axis = axis.GetData();
+  std::vector<int64_t> vec_dim;
+  if (reduce_all) {
+    if (keep_dim) {
+      vec_dim = std::vector<int64_t>(x.dims().size(), 1);
+    } else {
+      vec_dim = {1};
+    }
+  } else {
+    if (keep_dim) {
+      vec_dim = std::vector<int64_t>(x.dims().size(), -1);
+    } else {
+      auto x_rank = static_cast<size_t>(x.dims().size());
+      if (vec_axis.size() >= x_rank) {
+        vec_dim = {-1};
+      } else {
+        vec_dim = std::vector<int64_t>(x.dims().size() - vec_axis.size(), -1);
+      }
+    }
+  }
+  return phi::make_ddim(vec_dim);
+}
+
 void ReduceInferMeta(const MetaTensor& x,
                      const std::vector<int64_t>& axis,
                      bool keep_dim,
@@ -2678,6 +2705,34 @@ void ReduceInferMetaBase(const MetaTensor& x,
   out->set_layout(x.layout());
 }
 
+void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
+                                     const IntArray& axis,
+                                     bool keep_dim,
+                                     bool reduce_all,
+                                     MetaTensor* out,
+                                     MetaConfig config) {
+  if (config.is_runtime || !axis.FromTensor()) {
+    ReduceInferMetaBase(x, axis.GetData(), keep_dim, reduce_all, out);
+  } else {
+    DDim out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
+    out->set_dims(out_dim);
+    out->set_dtype(x.dtype());
+    out->set_layout(x.layout());
+  }
+}
+
+void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
+                                 const IntArray& axis,
+                                 bool keep_dim,
+                                 MetaTensor* out,
+                                 MetaConfig config) {
+  bool reduce_all = false;
+  if (axis.size() == 0) {
+    reduce_all = true;
+  }
+  ReduceIntArrayAxisInferMetaBase(x, axis, keep_dim, reduce_all, out, config);
+}
+
 void RepeatInterleaveInferMeta(const MetaTensor& x,
                                int repeats,
                                int dim,
@@ -3354,24 +3409,31 @@
    api.yaml
 */
 void SumInferMeta(const MetaTensor& x,
-                  const std::vector<int64_t>& axis,
+                  const IntArray& axis,
                   DataType dtype,
                   bool keep_dim,
-                  MetaTensor* out) {
+                  MetaTensor* out,
+                  MetaConfig config) {
   bool reduce_all = false;
   if (axis.size() == 0) {
     reduce_all = true;
   }
-  SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out);
+  SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out, config);
 }
 
 void SumRawInferMeta(const MetaTensor& x,
-                     const std::vector<int64_t>& axis,
+                     const IntArray& axis,
                      bool keep_dim,
                      bool reduce_all,
                      DataType dtype,
-                     MetaTensor* out) {
-  DDim out_dim = ReduceInferDim(x, axis, keep_dim, reduce_all);
+                     MetaTensor* out,
+                     MetaConfig config) {
+  DDim out_dim;
+  if (config.is_runtime || !axis.FromTensor()) {
+    out_dim = ReduceInferDim(x, axis.GetData(), keep_dim, reduce_all);
+  } else {
+    out_dim = ReduceInferDimForIntArrayAxis(x, axis, keep_dim, reduce_all);
+  }
 
   DataType out_dtype;
   if (dtype != DataType::UNDEFINED) {
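
The heart of the change is ReduceInferDimForIntArrayAxis: when the axis values live in a Tensor whose contents are not yet known (static-graph compile time), only the output rank can be pinned down, so the dims are filled with -1 placeholders. The expectations below are read directly off the branches above, assuming x.dims() == {4, 5, 6} and a tensor-backed axis holding one element:

// reduce_all == true,  keep_dim == true   -> {1, 1, 1}
// reduce_all == true,  keep_dim == false  -> {1}
// reduce_all == false, keep_dim == true   -> {-1, -1, -1}
// reduce_all == false, keep_dim == false  -> {-1, -1}  (rank 3 minus 1 axis)
// reduce_all == false, keep_dim == false, axis.size() >= 3 -> {-1}
//
// SumRawInferMeta and ReduceIntArrayAxisInferMetaBase take this placeholder
// path only when `!config.is_runtime && axis.FromTensor()`; at runtime, or
// with a host-backed axis, they fall back to the exact ReduceInferDim result.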
23 changes: 19 additions & 4 deletions paddle/phi/infermeta/unary.h
@@ -372,6 +372,19 @@ void ReduceInferMetaBase(const MetaTensor& x,
                          bool reduce_all,
                          MetaTensor* out);
 
+void ReduceIntArrayAxisInferMetaBase(const MetaTensor& x,
+                                     const IntArray& axis,
+                                     bool keep_dim,
+                                     bool reduce_all,
+                                     MetaTensor* out,
+                                     MetaConfig config = MetaConfig());
+
+void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
+                                 const IntArray& axis,
+                                 bool keep_dim,
+                                 MetaTensor* out,
+                                 MetaConfig config = MetaConfig());
+
 void RepeatInterleaveInferMeta(const MetaTensor& x,
                                int repeats,
                                int dim,
@@ -477,17 +490,19 @@ void StridedSliceInferMeta(const MetaTensor& x,
                            MetaConfig config = MetaConfig());
 
 void SumInferMeta(const MetaTensor& x,
-                  const std::vector<int64_t>& axis,
+                  const IntArray& axis,
                   DataType dtype,
                   bool keep_dim,
-                  MetaTensor* out);
+                  MetaTensor* out,
+                  MetaConfig config = MetaConfig());
 
 void SumRawInferMeta(const MetaTensor& x,
-                     const std::vector<int64_t>& axis,
+                     const IntArray& axis,
                      bool keep_dim,
                      bool reduce_all,
                      DataType dtype,
-                     MetaTensor* out);
+                     MetaTensor* out,
+                     MetaConfig config = MetaConfig());
 
 void SvdInferMeta(const MetaTensor& x,
                   bool full_matrices,
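
Both Sum*InferMeta declarations gain a trailing MetaConfig config = MetaConfig() parameter, so existing call sites keep compiling while static-graph inference can opt into the deferred-shape path. A minimal sketch of the two call styles (illustrative; x, out, and the public is_runtime field are assumed from context):

// Existing callers are unaffected: config defaults to MetaConfig().
SumInferMeta(x, phi::IntArray(std::vector<int64_t>{0}),
             phi::DataType::UNDEFINED, /*keep_dim=*/false, &out);

// Compile-time (static graph) inference passes a config with is_runtime
// unset, which triggers the placeholder-shape branch for tensor axes.
phi::MetaConfig config;
config.is_runtime = false;
SumInferMeta(x, tensor_backed_axis, phi::DataType::UNDEFINED,
             /*keep_dim=*/false, &out, config);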
8 changes: 4 additions & 4 deletions paddle/phi/kernels/cpu/graph_send_ue_recv_grad_kernel.cc
@@ -73,7 +73,7 @@ void CalculateXGrad(const Context& ctx,
       DenseTensor x_grad_out = phi::Sum<T, Context>(
           ctx,
           x_grad_v2,
-          reduce_idx,
+          phi::IntArray(reduce_idx),
          paddle::experimental::CppTypeToDataType<T>::Type(),
           true);
       memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
@@ -131,7 +131,7 @@ void CalculateXGrad(const Context& ctx,
       DenseTensor x_grad_out = phi::Sum<T, Context>(
           ctx,
           x_grad_v2,
-          reduce_idx,
+          phi::IntArray(reduce_idx),
          paddle::experimental::CppTypeToDataType<T>::Type(),
           true);
       memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
@@ -166,7 +166,7 @@ void CalculateXGrad(const Context& ctx,
       DenseTensor x_grad_out = phi::Sum<T, Context>(
           ctx,
           x_grad_v2,
-          reduce_idx,
+          phi::IntArray(reduce_idx),
          paddle::experimental::CppTypeToDataType<T>::Type(),
           true);
       memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
@@ -220,7 +220,7 @@ void CalculateXGrad(const Context& ctx,
       DenseTensor x_grad_out = phi::Sum<T, Context>(
           ctx,
           x_grad_v2,
-          reduce_idx,
+          phi::IntArray(reduce_idx),
          paddle::experimental::CppTypeToDataType<T>::Type(),
           true);
       memcpy(x_grad, x_grad_out.data<T>(), x_grad_out.numel() * sizeof(T));
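
These four call sites change only because phi::Sum now takes its axis as a phi::IntArray, so the index vector is wrapped explicitly. A minimal sketch of the updated call, assuming a phi::CPUContext named dev_ctx and a float DenseTensor named t:

// Sum over axis 0, keeping the reduced dimension; the axis vector is wrapped
// in phi::IntArray to match the new signature.
std::vector<int64_t> reduce_idx = {0};
phi::DenseTensor summed = phi::Sum<float, phi::CPUContext>(
    dev_ctx,
    t,
    phi::IntArray(reduce_idx),
    phi::DataType::FLOAT32,
    /*keep_dim=*/true);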