From 5ad3b1ecf1a4e2c4894fa1f3ebbe6438257e2339 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Wed, 20 Jul 2022 10:32:11 +0000 Subject: [PATCH 1/3] fix_fill_constant --- paddle/fluid/inference/tensorrt/convert/fill_constant_op.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/fluid/inference/tensorrt/convert/fill_constant_op.cc b/paddle/fluid/inference/tensorrt/convert/fill_constant_op.cc index 3bbbfe0374325..4d524c01b783f 100644 --- a/paddle/fluid/inference/tensorrt/convert/fill_constant_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fill_constant_op.cc @@ -32,6 +32,10 @@ class FillConstantOpConverter : public OpConverter { PADDLE_GET_CONST(std::string, op_desc.GetAttr("str_value")); std::vector shape = PADDLE_GET_CONST(std::vector, op_desc.GetAttr("shape")); + if (str_value == "") { + float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value")); + str_value = std::to_string(value); + } std::unique_ptr out_tensor(new framework::Tensor()); out_tensor->Resize(phi::make_ddim(shape)); nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT; From d4338f79bce759257f5a9a8f000313bf9e533f88 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Wed, 20 Jul 2022 13:07:08 +0000 Subject: [PATCH 2/3] fix_fill_constant --- .../ir/inference/test_trt_convert_fill_constant.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py index 84ee70782acc2..169086c262a1f 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py @@ -42,8 +42,14 @@ def generate_shapelist_data(attrs: List[Dict[str, Any]]): for dtype in [5, 2, 3]: for str_value in ["2", "23", "-1"]: self.num_input = num_input + value = float(str_value) + if np.random.choice([True, False]): + str_value = str_value + else: + str_value = "" dics = [{ "str_value": str_value, + "value": value, "shape": shape, "dtype": dtype }, { From 902cb684819a6fc25c7737be76c52bd7f3a98aef Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 21 Jul 2022 00:12:27 +0000 Subject: [PATCH 3/3] fix_ernie --- .../unittests/ir/inference/test_trt_convert_fill_constant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py index 169086c262a1f..cc686be6d8a83 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_fill_constant.py @@ -43,7 +43,7 @@ def generate_shapelist_data(attrs: List[Dict[str, Any]]): for str_value in ["2", "23", "-1"]: self.num_input = num_input value = float(str_value) - if np.random.choice([True, False]): + if np.random.choice([False, True]): str_value = str_value else: str_value = ""