Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unpool2d op & Expose max_unpool2d API #35056

Merged
merged 11 commits into from
Aug 27, 2021
1 change: 1 addition & 0 deletions paddle/fluid/operators/math/unpooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Unpool2dMaxFunctor<platform::CPUDeviceContext, T> {
for (int c = 0; c < output_channels; ++c) {
for (int i = 0; i < input_feasize; ++i) {
int index = indices_data[i];

PADDLE_ENFORCE_LT(
index, output_feasize,
platform::errors::InvalidArgument(
Expand Down
47 changes: 13 additions & 34 deletions paddle/fluid/operators/math/unpooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,48 +25,27 @@ __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
const int channels, T* output_data,
const int output_height,
const int output_width) {
int in_n_stride = input_height * input_width * channels;
int in_c_stride = input_height * input_width;
int out_n_stride = output_height * output_width * channels;
int out_c_stride = output_height * output_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int bidx = i / in_n_stride;
int boffset = i % in_n_stride;
int cidx = boffset / in_c_stride;
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
PADDLE_ENFORCE(out_index < out_c_stride,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为啥去掉这块的enforce呢?建议再check下数据检查

"out_index < out_c_stride. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
out_index, out_c_stride, out_index, out_c_stride);
output_data[out_offset + out_index] = input_data[i];
CUDA_KERNEL_LOOP(linearIndex, nthreads) {
int c = (linearIndex / input_width / input_height) % channels;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

异常数据注意enforce,否则安全扫描可能有问题

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thx,下一个PR中添加异常数据检查

int n = linearIndex / input_width / input_height / channels;
output_data += (n * channels + c) * output_height * output_width;
int maxind = indices_data[linearIndex];
output_data[maxind] = input_data[linearIndex];
}
}

// CUDA kernel for the backward pass of 2-D max-unpooling: for every input
// element it copies the output-gradient value selected by indices_data into
// input_grad.
//
// NOTE(review): this listing comes from a rendered diff, so two alternative
// bodies appear back to back and the braces do not balance. The stride-based
// `for` loop looks like the implementation being removed and the
// CUDA_KERNEL_LOOP body like its replacement -- confirm against the real file.
//
// Parameters (as used by the visible code):
//   nthreads       upper bound of the per-element loop.
//   indices_data   per-input-element index into one output plane.
//   input_height / input_width / channels / output_height / output_width
//                  sizes used to derive the plane offsets below
//                  (presumably a contiguous NCHW layout -- TODO confirm).
//   output_grad    gradient read through the stored index.
//   input_grad     gradient written, one value per loop index.
//   input_data / output_data are not referenced in the visible body.
template <typename T>
__global__ void KernelUnpool2dMaxGrad(
const int nthreads, const T* input_data, const int* indices_data,
const int input_height, const int input_width, const int channels,
const T* output_data, const T* output_grad, const int output_height,
const int output_width, T* input_grad) {
// ---- First body (stride-based version) ----
// Element counts of one batch item / one channel plane, for input and output.
int in_n_stride = input_height * input_width * channels;
int in_c_stride = input_height * input_width;
int out_n_stride = output_height * output_width * channels;
int out_c_stride = output_height * output_width;
// Each thread starts at its global thread index and advances by the total
// number of launched threads until nthreads is reached.
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
// Split the flat input index into batch index and channel index.
int bidx = i / in_n_stride;
int boffset = i % in_n_stride;
int cidx = boffset / in_c_stride;
// Start of the matching (batch, channel) plane in the output gradient.
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
// Rejects indices at or beyond the end of the output plane. Only the upper
// bound is checked here; a negative index would pass.
// NOTE(review): the message uses %ld while the arguments are int -- confirm
// this matches what PADDLE_ENFORCE expects.
PADDLE_ENFORCE(out_index < out_c_stride,
"out_index < out_c_stride. Expected %ld < %ld, but got "
"%ld >= %ld. Please check input value.",
out_index, out_c_stride, out_index, out_c_stride);
input_grad[i] = output_grad[out_offset + out_index];
// ---- Second body (CUDA_KERNEL_LOOP version) ----
CUDA_KERNEL_LOOP(linearIndex, nthreads) {
// Channel and batch index recovered from the flat input index.
int c = (linearIndex / input_width / input_height) % channels;
int n = linearIndex / input_width / input_height / channels;
// Advances the local output_grad pointer to the (n, c) output plane.
// NOTE(review): the pointer is modified in place inside the loop; if
// CUDA_KERNEL_LOOP executes more than one iteration per thread, the offsets
// would accumulate across iterations -- confirm against the macro definition
// and the launch configuration.
output_grad += (n * channels + c) * output_height * output_width;
// NOTE(review): maxind is used without any range check in this version.
int maxind = indices_data[linearIndex];
input_grad[linearIndex] = output_grad[maxind];
}
}
/*
Expand Down
23 changes: 17 additions & 6 deletions paddle/fluid/operators/unpool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"unpooling_type",
"(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"});
AddAttr<std::vector<int>>("output_size",
"(vector, optional). The shape of output.")
.SetDefault({0, 0});
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("NCHW");
AddComment(R"DOC(
Input shape is: $(N, C_{in}, H_{in}, W_{in})$, Output shape is:
$(N, C_{out}, H_{out}, W_{out})$, where
Expand Down Expand Up @@ -93,6 +103,8 @@ class UnpoolOp : public framework::OperatorWithKernel {
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
std::vector<int> output_size =
ctx->Attrs().Get<std::vector<int>>("output_size");
PADDLE_ENFORCE_EQ(in_x_dims.size() == 4, true,
platform::errors::InvalidArgument(
"Unpool Intput(X) must be of 4-dimensional, but "
Expand All @@ -111,8 +123,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
if (!ctx->IsRuntime() && in_x_dims[i + 2] <= 0) {
output_shape.push_back(-1);
} else {
output_shape.push_back(UnpoolOutputSize(in_x_dims[i + 2], ksize[i],
paddings[i], strides[i]));
output_shape.push_back(output_size[i]);
}
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
Expand Down Expand Up @@ -156,15 +167,15 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(unpool, ops::UnpoolOp, ops::Unpool2dOpMaker,
REGISTER_OPERATOR(unpool2d, ops::UnpoolOp, ops::Unpool2dOpMaker,
ops::UnpoolOpGradMaker<paddle::framework::OpDesc>,
ops::UnpoolOpGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(unpool_grad, ops::UnpoolOpGrad);
REGISTER_OPERATOR(unpool2d_grad, ops::UnpoolOpGrad);
REGISTER_OP_CPU_KERNEL(
unpool, ops::UnpoolKernel<paddle::platform::CPUDeviceContext, float>,
unpool2d, ops::UnpoolKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnpoolKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
unpool_grad,
unpool2d_grad,
ops::UnpoolGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnpoolGradKernel<paddle::platform::CPUDeviceContext, double>);
4 changes: 2 additions & 2 deletions paddle/fluid/operators/unpool_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ limitations under the License. */

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
unpool, ops::UnpoolKernel<paddle::platform::CUDADeviceContext, float>,
unpool2d, ops::UnpoolKernel<paddle::platform::CUDADeviceContext, float>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为改名,需要确认下是否之前有API使用了unpool

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在全景图里做了搜索,确认没有api使用过unpool

ops::UnpoolKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
unpool_grad,
unpool2d_grad,
ops::UnpoolGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnpoolGradKernel<paddle::platform::CUDADeviceContext, double>);
Loading