Skip to content

Commit

Permalink
fix prelu trt convert
Browse files Browse the repository at this point in the history
  • Loading branch information
JZZ-NOTE committed Feb 8, 2022
1 parent 24b2e8e commit 1c6dd83
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1203,8 +1203,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) {
VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt.";
if (!with_dynamic_shape && x_shape.size() == 1) {
VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt "
"with static shape.";
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,29 +172,30 @@ def clear_dynamic_shape():
for i in range(len(program_config.ops))
]

def generate_trt_nodes_num(attrs, dynamic_shape):
if not dynamic_shape and self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0:
return 0, 3
return 1, 2

# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 2), 1e-5
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 2), 1e-5
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5

# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (1, 2), 1e-5
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 2), 1e-5
yield self.create_inference_config(), generate_trt_nodes_num(attrs,
True), 1e-5

def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0:
return True
return False

self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT,
"Trt does not support 1-dimensional input.")

ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000:

Expand Down

1 comment on commit 1c6dd83

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.