Skip to content

Commit

Permalink
[Paddle-TRT] fix_fill_constant (PaddlePaddle#44481)
Browse files Browse the repository at this point in the history
* fix_fill_constant

* fix_fill_constant

* fix_ernie
  • Loading branch information
zhoutianzi666 authored and Aurelius84 committed Jul 29, 2022
1 parent 912664d commit cad9755
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/fill_constant_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class FillConstantOpConverter : public OpConverter {
PADDLE_GET_CONST(std::string, op_desc.GetAttr("str_value"));
std::vector<int64_t> shape =
PADDLE_GET_CONST(std::vector<int64_t>, op_desc.GetAttr("shape"));
if (str_value == "") {
float value = PADDLE_GET_CONST(float, op_desc.GetAttr("value"));
str_value = std::to_string(value);
}
std::unique_ptr<framework::Tensor> out_tensor(new framework::Tensor());
out_tensor->Resize(phi::make_ddim(shape));
nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,14 @@ def generate_shapelist_data(attrs: List[Dict[str, Any]]):
for dtype in [5, 2, 3]:
for str_value in ["2", "23", "-1"]:
self.num_input = num_input
value = float(str_value)
if np.random.choice([False, True]):
str_value = str_value
else:
str_value = ""
dics = [{
"str_value": str_value,
"value": value,
"shape": shape,
"dtype": dtype
}, {
Expand Down

0 comments on commit cad9755

Please sign in to comment.