-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add unpool2d op & Expose max_unpool2d API #35056
Changes from 6 commits
d870d1a
1f5eab3
5266ba7
099cd32
f350b74
aa97717
192d039
fca6240
e2d04e9
cc21c42
b6dda59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,48 +25,27 @@ __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data, | |
const int channels, T* output_data, | ||
const int output_height, | ||
const int output_width) { | ||
int in_n_stride = input_height * input_width * channels; | ||
int in_c_stride = input_height * input_width; | ||
int out_n_stride = output_height * output_width * channels; | ||
int out_c_stride = output_height * output_width; | ||
int index = blockIdx.x * blockDim.x + threadIdx.x; | ||
int offset = blockDim.x * gridDim.x; | ||
for (int i = index; i < nthreads; i += offset) { | ||
int bidx = i / in_n_stride; | ||
int boffset = i % in_n_stride; | ||
int cidx = boffset / in_c_stride; | ||
int out_offset = bidx * out_n_stride + cidx * out_c_stride; | ||
int out_index = indices_data[i]; | ||
PADDLE_ENFORCE(out_index < out_c_stride, | ||
"out_index < out_c_stride. Expected %ld < %ld, but got " | ||
"%ld >= %ld. Please check input value.", | ||
out_index, out_c_stride, out_index, out_c_stride); | ||
output_data[out_offset + out_index] = input_data[i]; | ||
CUDA_KERNEL_LOOP(linearIndex, nthreads) { | ||
int c = (linearIndex / input_width / input_height) % channels; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 异常数据注意enforce,否则安全扫描可能有问题 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thx,下一个PR中添加异常数据检查 |
||
int n = linearIndex / input_width / input_height / channels; | ||
output_data += (n * channels + c) * output_height * output_width; | ||
int maxind = indices_data[linearIndex]; | ||
output_data[maxind] = input_data[linearIndex]; | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void KernelUnpool2dMaxGrad( | ||
const int nthreads, const T* input_data, const int* indices_data, | ||
const int input_height, const int input_width, const int channels, | ||
const T* output_data, const T* output_grad, const int output_height, | ||
const int output_width, T* input_grad) { | ||
int in_n_stride = input_height * input_width * channels; | ||
int in_c_stride = input_height * input_width; | ||
int out_n_stride = output_height * output_width * channels; | ||
int out_c_stride = output_height * output_width; | ||
int index = blockIdx.x * blockDim.x + threadIdx.x; | ||
int offset = blockDim.x * gridDim.x; | ||
for (int i = index; i < nthreads; i += offset) { | ||
int bidx = i / in_n_stride; | ||
int boffset = i % in_n_stride; | ||
int cidx = boffset / in_c_stride; | ||
int out_offset = bidx * out_n_stride + cidx * out_c_stride; | ||
int out_index = indices_data[i]; | ||
PADDLE_ENFORCE(out_index < out_c_stride, | ||
"out_index < out_c_stride. Expected %ld < %ld, but got " | ||
"%ld >= %ld. Please check input value.", | ||
out_index, out_c_stride, out_index, out_c_stride); | ||
input_grad[i] = output_grad[out_offset + out_index]; | ||
CUDA_KERNEL_LOOP(linearIndex, nthreads) { | ||
int c = (linearIndex / input_width / input_height) % channels; | ||
int n = linearIndex / input_width / input_height / channels; | ||
output_grad += (n * channels + c) * output_height * output_width; | ||
int maxind = indices_data[linearIndex]; | ||
input_grad[linearIndex] = output_grad[maxind]; | ||
} | ||
} | ||
/* | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,9 +16,9 @@ limitations under the License. */ | |
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_CUDA_KERNEL( | ||
unpool, ops::UnpoolKernel<paddle::platform::CUDADeviceContext, float>, | ||
unpool2d, ops::UnpoolKernel<paddle::platform::CUDADeviceContext, float>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 因为改名,需要确认下是否之前有API使用了unpool There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 在全景图里做了搜索,确认没有api使用过unpool |
||
ops::UnpoolKernel<paddle::platform::CUDADeviceContext, double>); | ||
REGISTER_OP_CUDA_KERNEL( | ||
unpool_grad, | ||
unpool2d_grad, | ||
ops::UnpoolGradKernel<paddle::platform::CUDADeviceContext, float>, | ||
ops::UnpoolGradKernel<paddle::platform::CUDADeviceContext, double>); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为啥去掉这块的enforce呢?建议再check下数据数据检查