Skip to content

Commit

Permalink
fix prelu trt convert
Browse files Browse the repository at this point in the history
  • Loading branch information
JZZ-NOTE committed Feb 8, 2022
1 parent 24b2e8e commit 1c6dd83
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1203,8 +1203,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) {
VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt.";
if (!with_dynamic_shape && x_shape.size() == 1) {
VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt "
"with static shape.";
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,29 +172,30 @@ def clear_dynamic_shape():
for i in range(len(program_config.ops))
]

def generate_trt_nodes_num(attrs, dynamic_shape):
if not dynamic_shape and self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0:
return 0, 3
return 1, 2

# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 2), 1e-5
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 2), 1e-5
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5

# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 2), 1e-5
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 2), 1e-5
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5

def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0:
return True
return False

self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT,
"Trt does not support 1-dimensional input.")

ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:

Expand Down

0 comments on commit 1c6dd83

Please sign in to comment.