Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 4 No.17】Add cummax / cummin API to Paddle #53546

Merged
merged 31 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
c4f6a10
add API and op for cummax/min
Patrick-Star125 May 2, 2023
ff601cc
test pass stage 1
Patrick-Star125 May 6, 2023
fe0e522
format code
Patrick-Star125 May 7, 2023
90c50d3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Patrick-Star125 May 29, 2023
32cc8c5
fix bug
Patrick-Star125 May 29, 2023
8c75a58
refine test sample
Patrick-Star125 May 30, 2023
77cd93e
fix bug
Patrick-Star125 May 30, 2023
9c8fe91
fix bug
Patrick-Star125 May 30, 2023
c9ca0df
fix bug
Patrick-Star125 May 30, 2023
e5102e1
fix bug
Patrick-Star125 May 31, 2023
cecfa73
fix bug
Patrick-Star125 May 31, 2023
343d5d9
format
Patrick-Star125 May 31, 2023
d438ec9
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
Patrick-Star125 May 31, 2023
f398100
fix bug
Patrick-Star125 May 31, 2023
1919e10
fix bug
Patrick-Star125 May 31, 2023
1b0f4e7
fix bug
Patrick-Star125 May 31, 2023
24c27a3
fix bug
Patrick-Star125 Jun 1, 2023
d17bf57
fix bug
Patrick-Star125 Jun 1, 2023
b3d1863
fix bug
Patrick-Star125 Jun 1, 2023
d8aefa2
fix bug
Patrick-Star125 Jun 1, 2023
60f08f6
merge main
Patrick-Star125 Jun 1, 2023
1ea7738
fix bug
Patrick-Star125 Jun 1, 2023
1ee7d80
format
Patrick-Star125 Jun 1, 2023
722912f
format
Patrick-Star125 Jun 1, 2023
51042b7
format
Patrick-Star125 Jun 1, 2023
ec992f1
format
Patrick-Star125 Jun 1, 2023
18ccdaf
adjust test file position
Patrick-Star125 Jun 5, 2023
401a944
adjust test file position
Patrick-Star125 Jun 5, 2023
406ba76
restore test place
Patrick-Star125 Jun 7, 2023
4011a80
correct out description
Patrick-Star125 Jun 9, 2023
f60a11d
format
Patrick-Star125 Jun 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,26 @@
func : cross_grad
data_type : out_grad

- backward_op : cummax_grad
forward : cummax(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这样是没问题的,这里的x可以优化掉吗,似乎只会用到输入的维度

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image

删除后在test_check_grad处会报参数数量不匹配,但是我无法通过删除test_check_grad参数解决这个问题

请问删除x是必要的吗,这里x传引用似乎不会有内存额外占用问题

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

看起来是kernel没有同步改?功能上是不必要的,这里主要是能优化掉的话,反向计算的时就少一个对x的引用,可能就能释放掉x

output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : cummax_grad

- backward_op : cummin_grad
forward : cummin(Tensor x, int axis=-1, int dtype=3) -> Tensor(out), Tensor(indices)
args : (Tensor x, Tensor indices, Tensor out_grad, int axis, int dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,kernel也一样

output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param: [x]
kernel :
func : cummin_grad

- backward_op : cumprod_grad
forward : cumprod (Tensor x, int dim) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int dim)
Expand Down
18 changes: 18 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,24 @@
data_type : input
backward : cross_entropy_with_softmax_grad

# cummax: cumulative (running) maximum of `x` along `axis`.
# Returns the running-max values (`out`, same shape/dtype as `x`) and the
# positions of each running maximum (`indices`).
# NOTE(review): dtype=3 is presumably the framework's enum code for int64
# (so indices default to int64) — confirm against the DataType proto.
- op : cummax
  args : (Tensor x, int axis=-1, int dtype=3)
  output : Tensor(out), Tensor(indices)
  infer_meta :
    func : CumWithIndicesInferMeta
  kernel :
    func : cummax
  backward : cummax_grad

# cummin: cumulative (running) minimum of `x` along `axis`.
# Returns the running-min values (`out`, same shape/dtype as `x`) and the
# positions of each running minimum (`indices`).
# NOTE(review): dtype=3 is presumably the framework's enum code for int64
# (so indices default to int64) — confirm against the DataType proto.
- op : cummin
  args : (Tensor x, int axis=-1, int dtype=3)
  output : Tensor(out), Tensor(indices)
  infer_meta :
    func : CumWithIndicesInferMeta
  kernel :
    func : cummin
  backward : cummin_grad

- op : cumprod
args : (Tensor x, int dim)
output : Tensor(out)
Expand Down
63 changes: 63 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,69 @@ void CumScalarAxisInferMeta(const MetaTensor& x,
CumInferMeta(x, axis.to<int>(), flatten, exclusive, reverse, out);
}

// Shape/dtype inference shared by cummax and cummin.
//
// `out` mirrors x's shape and dtype; `indices` mirrors x's shape with the
// integer dtype requested via `dtype` (must map to int32 or int64).
//
// BUGFIX: the axis range validation now runs BEFORE the int32-overflow
// check. The previous order indexed phi::vectorize(x_dims)[_axis] with an
// unvalidated axis, which was out-of-bounds UB for a 0-D tensor with
// axis == -1 (empty shape vector, index -1) and for any out-of-range axis.
void CumWithIndicesInferMeta(const MetaTensor& x,
                             int axis,
                             int dtype,
                             MetaTensor* out,
                             MetaTensor* indices) {
  auto x_dims = x.dims();
  auto indices_type = phi::TransToPhiDataType(dtype);
  PADDLE_ENFORCE_EQ(
      (indices_type == DataType::INT32 || indices_type == DataType::INT64),
      true,
      phi::errors::InvalidArgument("dtype of indices must be int32 or int64"));

  // Validate `axis` first so every later use of it is safe.
  if (x_dims.size() > 0) {
    PADDLE_ENFORCE_GE(
        axis,
        -x_dims.size(),
        phi::errors::OutOfRange(
            "axis is out of range (expected to be in range of [%ld, "
            "%ld), but got %ld).",
            -(x_dims.size()),
            x_dims.size(),
            axis));
    PADDLE_ENFORCE_LT(
        axis,
        x_dims.size(),
        phi::errors::OutOfRange(
            "axis is out of range (expected to be in range of [%ld, "
            "%ld), but got %ld).",
            -(x_dims.size()),
            x_dims.size(),
            axis));
  } else {
    // A 0-D tensor only accepts the degenerate axis values 0 and -1.
    PADDLE_ENFORCE_EQ(
        (axis == 0 || axis == -1),
        true,
        errors::InvalidArgument("The axis must be -1 or 0 in 0D Tensor, "
                                "but the value given is %d.",
                                axis));
  }

  // int32 indices cannot address a dimension longer than INT32_MAX; the
  // check only makes sense for rank > 0 (a 0-D tensor has no extent).
  if (indices_type == DataType::INT32 && x_dims.size() > 0) {
    int _axis = (axis < 0) ? axis + x_dims.size() : axis;
    PADDLE_ENFORCE_LT(
        phi::vectorize(x_dims)[_axis],
        INT32_MAX,
        phi::errors::OutOfRange(
            "cummax with axis %ld may be overflow, set dtype int64 to continue",
            axis));
  }

  out->set_dims(x_dims);
  out->set_dtype(x.dtype());
  out->share_lod(x);
  indices->set_dims(x_dims);
  indices->set_dtype(indices_type);
  indices->share_lod(x);
}

void CropInferMeta(const MetaTensor& x,
const IntArray& shape,
const IntArray& offsets,
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ void CumScalarAxisInferMeta(const MetaTensor& x,
bool reverse,
MetaTensor* out);

// Infers output metadata for cummax/cummin: `out` takes x's dims and dtype,
// `indices` takes x's dims with the integer dtype selected by `dtype`
// (restricted to int32/int64 by the implementation).
void CumWithIndicesInferMeta(const MetaTensor& x,
                             int axis,
                             int dtype,
                             MetaTensor* out,
                             MetaTensor* indices);

void DecodeJpegInferMeta(const MetaTensor& x,
const std::string& mode,
MetaTensor* out);
Expand Down
91 changes: 91 additions & 0 deletions paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/cum_maxmin_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather_scatter_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace phi {

template <typename T, typename Context>
void CummaxGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
}

template <typename T, typename Context>
void CumminGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& indices,
const DenseTensor& out_grad,
int axis,
int dtype,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
phi::funcs::SetConstant<Context, T> functor;
functor(dev_ctx, x_grad, static_cast<T>(0));
if (axis < 0) {
axis = axis + x.dims().size();
}
auto indices_type = phi::TransToPhiDataType(dtype);
if (indices_type == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
} else if (indices_type == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
}
}

} // namespace phi

// Register the CPU backward kernels for cummax/cummin over the supported
// value types (float, double, int32, int64); the index dtype is dispatched
// inside each kernel rather than via registration.
PD_REGISTER_KERNEL(cummax_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::CummaxGradKernel,
                   float,
                   double,
                   int32_t,
                   int64_t) {}

PD_REGISTER_KERNEL(cummin_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::CumminGradKernel,
                   float,
                   double,
                   int32_t,
                   int64_t) {}
Loading