From 1c6dd8345ce7f398097abd84d74e803227a0601f Mon Sep 17 00:00:00 2001 From: JZZ-NOTE Date: Tue, 8 Feb 2022 07:34:06 +0000 Subject: [PATCH] fix prelu trt convert --- paddle/fluid/inference/tensorrt/op_teller.cc | 5 ++-- .../ir/inference/test_trt_convert_prelu.py | 25 ++++++++++--------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 5e320a027022f..4c8d9d50965c0 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1203,8 +1203,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVar(x_var_name); const auto x_shape = x_var_desc->GetShape(); - if (x_shape.size() == 1) { - VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt."; + if (!with_dynamic_shape && x_shape.size() == 1) { + VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt " + "with static shape."; return false; } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py index 5153476ae19f1..10109cdc73a2b 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_prelu.py @@ -172,29 +172,30 @@ def clear_dynamic_shape(): for i in range(len(program_config.ops)) ] + def generate_trt_nodes_num(attrs, dynamic_shape): + if not dynamic_shape and self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0: + return 0, 3 + return 1, 2 + # for static_shape clear_dynamic_shape() self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 # for dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 self.trt_param.precision = paddle_infer.PrecisionType.Half - yield self.create_inference_config(), (1, 2), 1e-5 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 def add_skip_trt_case(self): - def teller1(program_config, predictor_config): - if self.dim1 == 0 and self.dim2 == 0 and self.dim3 == 0: - return True - return False - - self.add_skip_case(teller1, SkipReasons.TRT_NOT_SUPPORT, - "Trt does not support 1-dimensional input.") - ver = paddle_infer.get_trt_compile_version() if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000: