Skip to content

Commit

Permalink
change api to support trt8 in pool3d_op_convert (PaddlePaddle#36783)
Browse files Browse the repository at this point in the history
* change api for support trt8

* fix:change api
  • Loading branch information
feng_shuai authored and piotrekobi committed Nov 3, 2021
1 parent 201c923 commit c3e81ba
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions paddle/fluid/inference/tensorrt/convert/pool3d_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ namespace tensorrt {

inline void DealCeilMode(const nvinfer1::Dims &input_shape,
std::vector<int> ksize, std::vector<int> strides,
std::vector<int> paddings, nvinfer1::DimsCHW *pre_pad,
nvinfer1::DimsCHW *post_pad, int input_dims) {
std::vector<int> paddings, nvinfer1::Dims3 *pre_pad,
nvinfer1::Dims3 *post_pad, int input_dims) {
int input_depth = input_shape.d[input_dims - 3];
int input_height = input_shape.d[input_dims - 2];
int input_width = input_shape.d[input_dims - 1];
Expand All @@ -56,15 +56,15 @@ inline void DealCeilMode(const nvinfer1::Dims &input_shape,
1;

if (floor_d_output_size != ceil_d_output_size) {
post_pad->c() = strides[0] - 1;
post_pad->d[0] = strides[0] - 1;
}

if (floor_h_output_size != ceil_h_output_size) {
post_pad->h() = strides[1] - 1;
post_pad->d[1] = strides[1] - 1;
}

if (floor_w_output_size != ceil_w_output_size) {
post_pad->w() = strides[2] - 1;
post_pad->d[2] = strides[2] - 1;
}
}

Expand Down Expand Up @@ -118,9 +118,9 @@ class Pool3dOpConverter : public OpConverter {
reduce_operation = nvinfer1::ReduceOperation::kAVG;
plugin_pool_type = plugin::Pool3DPlugin::Pool3DType::avg;
}
nvinfer1::DimsCHW nv_ksize(ksize[0], ksize[1], ksize[2]);
nvinfer1::DimsCHW nv_strides(strides[0], strides[1], strides[2]);
nvinfer1::DimsCHW nv_paddings(paddings[0], paddings[1], paddings[2]);
nvinfer1::Dims3 nv_ksize(ksize[0], ksize[1], ksize[2]);
nvinfer1::Dims3 nv_strides(strides[0], strides[1], strides[2]);
nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]);
nvinfer1::ILayer *layer = nullptr;
if (op_desc.HasAttr("enable_int8")) {
CHECK(op_desc.HasAttr("X_scale"));
Expand Down

0 comments on commit c3e81ba

Please sign in to comment.