diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index ddee4e0d682b0..f9b7331d8bdd1 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -486,7 +486,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, BOOST_GET_CONST(std::vector, desc.GetAttr("axis")); if (!with_dynamic_shape && axis[0] != 0) return false; if (axis.size() >= nvinfer1::Dims::MAX_DIMS) return false; - if (axis[0] == 0 && axis.size() == 2) return false; auto* block = desc.Block(); if (block == nullptr) { @@ -498,7 +497,9 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVar(x_var_name); const auto x_shape = x_var_desc->GetShape(); + if (axis.size() != x_shape.size()) return false; int dims = x_shape.size(); + std::vector perm(nvinfer1::Dims::MAX_DIMS); for (int i = 0; i < dims; i++) { perm[i] = axis[i]; @@ -517,6 +518,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, if (!is_valid_permutation(dims, perm)) { VLOG(3) << "Invalid permutation dimensions for trt transpose op " "converter: duplicate or out of bound."; + return false; } } if (op_type == "flatten2" || op_type == "flatten") { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py index 31b4d027f1780..87e81396ab411 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_transpose.py @@ -146,19 +146,7 @@ def generate_trt_nodes_num(attrs, dynamic_shape): yield self.create_inference_config(), generate_trt_nodes_num(attrs, True), 1e-5 - def add_skip_trt_case(self): - def teller1(program_config, predictor_config): - if program_config.ops[0].attrs['axis'] == [0, 1]: - return True - return False - - self.add_skip_case( - teller1, SkipReasons.TRT_NOT_IMPLEMENTED, - "INPUT AXIS [0, 1] NOT SUPPORT: we need to add support in the future" - ) - def test(self): - self.add_skip_trt_case() self.run_test()