Skip to content

Commit

Permalink
change api for support trt8
Browse files Browse the repository at this point in the history
  • Loading branch information
fengshuai03 committed Oct 27, 2021
1 parent db633af commit c0b5fbd
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions paddle/fluid/inference/tensorrt/convert/pool3d_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ namespace tensorrt {

inline void DealCeilMode(const nvinfer1::Dims &input_shape,
std::vector<int> ksize, std::vector<int> strides,
std::vector<int> paddings, nvinfer1::DimsCHW *pre_pad,
nvinfer1::DimsCHW *post_pad, int input_dims) {
std::vector<int> paddings, nvinfer1::Dims3 *pre_pad,
nvinfer1::Dims3 *post_pad, int input_dims) {
int input_depth = input_shape.d[input_dims - 3];
int input_height = input_shape.d[input_dims - 2];
int input_width = input_shape.d[input_dims - 1];
Expand Down Expand Up @@ -118,9 +118,9 @@ class Pool3dOpConverter : public OpConverter {
reduce_operation = nvinfer1::ReduceOperation::kAVG;
plugin_pool_type = plugin::Pool3DPlugin::Pool3DType::avg;
}
nvinfer1::DimsCHW nv_ksize(ksize[0], ksize[1], ksize[2]);
nvinfer1::DimsCHW nv_strides(strides[0], strides[1], strides[2]);
nvinfer1::DimsCHW nv_paddings(paddings[0], paddings[1], paddings[2]);
nvinfer1::Dims3 nv_ksize(ksize[0], ksize[1], ksize[2]);
nvinfer1::Dims3 nv_strides(strides[0], strides[1], strides[2]);
nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]);
nvinfer1::ILayer *layer = nullptr;
if (op_desc.HasAttr("enable_int8")) {
CHECK(op_desc.HasAttr("X_scale"));
Expand Down

0 comments on commit c0b5fbd

Please sign in to comment.