Skip to content

Commit

Permalink
[bug fix] fix unfold runtime bug (#38819)
Browse files Browse the repository at this point in the history
  • Loading branch information
ghostxsl committed Jan 10, 2022
1 parent a8afed6 commit 5c35750
Showing 1 changed file with 41 additions and 34 deletions.
75 changes: 41 additions & 34 deletions paddle/fluid/operators/unfold_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,40 +143,47 @@ class UnfoldOp : public framework::OperatorWithKernel {
"but recieved dilations_height: %d dilations_width: %d.",
dilations[0], dilations[1]));

std::vector<int> out_dims;
out_dims.push_back(in_dims[0]);

int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1];
out_dims.push_back(output_channels);

int output_height =
CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0],
paddings[2], strides[0]);
int output_width = CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1],
paddings[1], paddings[3], strides[1]);
// check output height and width
PADDLE_ENFORCE_GT(
output_height, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size (%d, %d), "
"kernel_sizes (%d, %d), strides (%d, %d), dilations (%d, %d), "
"is (%d, %d), which should be a positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
PADDLE_ENFORCE_GT(
output_width, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size (%d, %d), "
"kernel_sizes (%d, %d), strides (%d, %d), dilations (%d, %d), "
"is (%d, %d), which should be a positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
int output_col_length = output_height * output_width;
out_dims.push_back(output_col_length);

ctx->SetOutputDim("Y", framework::make_ddim(out_dims));
bool contain_unknown_dim = framework::contain_unknown_dim(in_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
std::vector<int> out_dims;
out_dims.push_back(in_dims[0]);

int output_channels = in_dims[1] * kernel_sizes[0] * kernel_sizes[1];
out_dims.push_back(output_channels);

int output_height =
CalcOutputSize(in_dims[2], kernel_sizes[0], dilations[0], paddings[0],
paddings[2], strides[0]);
int output_width =
CalcOutputSize(in_dims[3], kernel_sizes[1], dilations[1], paddings[1],
paddings[3], strides[1]);
// check output height and width
PADDLE_ENFORCE_GT(
output_height, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size "
"(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
"dilations (%d, %d), is (%d, %d), which should be a "
"positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
PADDLE_ENFORCE_GT(
output_width, 0,
platform::errors::InvalidArgument(
"The sliding blocks calculated from input spatial size "
"(%d, %d), kernel_sizes (%d, %d), strides (%d, %d), "
"dilations (%d, %d), is (%d, %d), which should be a "
"positive integer.",
in_dims[2], in_dims[3], kernel_sizes[0], kernel_sizes[1],
strides[0], strides[1], dilations[0], dilations[1], output_height,
output_width));
int output_col_length = output_height * output_width;
out_dims.push_back(output_col_length);

ctx->SetOutputDim("Y", framework::make_ddim(out_dims));
}
}

protected:
Expand Down

0 comments on commit 5c35750

Please sign in to comment.