-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add unpool2d op & Expose max_unpool2d API #35056
Merged
Merged
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
d870d1a
add maxunppol2d op, test=develop
tink2123 1f5eab3
fix typo, test=develop
tink2123 5266ba7
fix unpool unitest, test=develop
tink2123 099cd32
fix unpool code-example, test=develop
tink2123 f350b74
fix for unpool_op_unittest,test=develop
tink2123 aa97717
fix example code, test=develop
tink2123 192d039
add noqa:F401, test=develop
tink2123 fca6240
fix converage, test=develop
tink2123 e2d04e9
fix unitest for unpool, test=develop
tink2123 cc21c42
rename unpool2d to unpool, test=develop
tink2123 b6dda59
rename unpool2d to unpool, test=develop
tink2123 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,48 +25,27 @@ __global__ void KernelUnpool2dMax(const int nthreads, const T* input_data, | |
const int channels, T* output_data, | ||
const int output_height, | ||
const int output_width) { | ||
int in_n_stride = input_height * input_width * channels; | ||
int in_c_stride = input_height * input_width; | ||
int out_n_stride = output_height * output_width * channels; | ||
int out_c_stride = output_height * output_width; | ||
int index = blockIdx.x * blockDim.x + threadIdx.x; | ||
int offset = blockDim.x * gridDim.x; | ||
for (int i = index; i < nthreads; i += offset) { | ||
int bidx = i / in_n_stride; | ||
int boffset = i % in_n_stride; | ||
int cidx = boffset / in_c_stride; | ||
int out_offset = bidx * out_n_stride + cidx * out_c_stride; | ||
int out_index = indices_data[i]; | ||
PADDLE_ENFORCE(out_index < out_c_stride, | ||
"out_index < out_c_stride. Expected %ld < %ld, but got " | ||
"%ld >= %ld. Please check input value.", | ||
out_index, out_c_stride, out_index, out_c_stride); | ||
output_data[out_offset + out_index] = input_data[i]; | ||
CUDA_KERNEL_LOOP(linearIndex, nthreads) { | ||
int c = (linearIndex / input_width / input_height) % channels; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 异常数据注意enforce,否则安全扫描可能有问题 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thx,下一个PR中添加异常数据检查 |
||
int n = linearIndex / input_width / input_height / channels; | ||
output_data += (n * channels + c) * output_height * output_width; | ||
int maxind = indices_data[linearIndex]; | ||
output_data[maxind] = input_data[linearIndex]; | ||
} | ||
} | ||
|
||
template <typename T> | ||
__global__ void KernelUnpool2dMaxGrad( | ||
const int nthreads, const T* input_data, const int* indices_data, | ||
const int input_height, const int input_width, const int channels, | ||
const T* output_data, const T* output_grad, const int output_height, | ||
const int output_width, T* input_grad) { | ||
int in_n_stride = input_height * input_width * channels; | ||
int in_c_stride = input_height * input_width; | ||
int out_n_stride = output_height * output_width * channels; | ||
int out_c_stride = output_height * output_width; | ||
int index = blockIdx.x * blockDim.x + threadIdx.x; | ||
int offset = blockDim.x * gridDim.x; | ||
for (int i = index; i < nthreads; i += offset) { | ||
int bidx = i / in_n_stride; | ||
int boffset = i % in_n_stride; | ||
int cidx = boffset / in_c_stride; | ||
int out_offset = bidx * out_n_stride + cidx * out_c_stride; | ||
int out_index = indices_data[i]; | ||
PADDLE_ENFORCE(out_index < out_c_stride, | ||
"out_index < out_c_stride. Expected %ld < %ld, but got " | ||
"%ld >= %ld. Please check input value.", | ||
out_index, out_c_stride, out_index, out_c_stride); | ||
input_grad[i] = output_grad[out_offset + out_index]; | ||
CUDA_KERNEL_LOOP(linearIndex, nthreads) { | ||
int c = (linearIndex / input_width / input_height) % channels; | ||
int n = linearIndex / input_width / input_height / channels; | ||
output_grad += (n * channels + c) * output_height * output_width; | ||
int maxind = indices_data[linearIndex]; | ||
input_grad[linearIndex] = output_grad[maxind]; | ||
} | ||
} | ||
/* | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为啥去掉这块的enforce呢?建议再check下数据数据检查