Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#225 from JZZ-NOTE/tune_demo
Browse files Browse the repository at this point in the history
add python tune demo
  • Loading branch information
Wangzheee committed Mar 24, 2022
2 parents 9dcbc76 + a6deece commit eb95780
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 0 deletions.
30 changes: 30 additions & 0 deletions python/paddle_trt/tuned_dynamic_shape/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
## 使用 Paddle-TRT TunedDynamicShape 能力

该文档为使用 Paddle-TRT TunedDynamicShape 的实践 demo。如果您刚接触 Paddle-TRT,推荐先访问[这里](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html)对 Paddle-TRT 有个初步认识。

### 准备环境

请您在环境中安装2.0或以上版本的 Paddle,具体的安装方式请参照[飞桨官方页面](https://www.paddlepaddle.org.cn/)的指示方式。

### 下载测试模型

下载[模型](https://paddle-inference-dist.bj.bcebos.com/Paddle-Inference-Demo/resnet50.tgz),模型为 imagenet 数据集训练得到的,如果你想获取更多的模型训练信息,请访问[这里](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification)。解压后存储到该工程的根目录。

#### TunedDynamicShape 测试

**1、首先您需要针对业务数据进行离线 tune,来获取计算图中所有中间 tensor 的 shape 范围,并将其存储在 config 中配置的 shape_range_info.pbtxt 文件中**

```
python infer_tune.py --model_file ./resnet50/inference.pdmodel --params_file ./resnet50/inference.pdiparams --tune 1
```

**2、有了离线 tune 得到的 shape 范围信息后,您可以使用该文件自动对所有的 trt 子图设置其输入的 shape 范围。**

```
python infer_tune.py --model_file ./resnet50/inference.pdmodel --params_file ./resnet50/inference.pdiparams --use_gpu 1 --use_trt 1 --tuned_dynamic_shape 1
```

### 更多链接
- [Paddle Inference使用Quick Start!](https://paddle-inference.readthedocs.io/en/latest/introduction/quick_start.html)
- [Paddle Inference C++ Api使用](https://paddle-inference.readthedocs.io/en/latest/user_guides/cxx_api.html)
- [Paddle Inference Python Api使用](https://paddle-inference.readthedocs.io/en/latest/user_guides/inference_python_api.html)
108 changes: 108 additions & 0 deletions python/paddle_trt/tuned_dynamic_shape/infer_tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import numpy as np
import argparse

from paddle.inference import Config
from paddle.inference import create_predictor
from paddle.inference import PrecisionType

shape_file = "shape_range_info.pbtxt"


def init_predictor(args):
if args.model_dir is not "":
config = Config(args.model_dir)
else:
config = Config(args.model_file, args.params_file)

config.enable_memory_optim()
if args.tune:
config.collect_shape_range_info(shape_file)
if args.use_gpu:
config.enable_use_gpu(1000, 0)
if args.use_trt:
# using dynamic shpae mode, the max_batch_size will be ignored.
config.enable_tensorrt_engine(
workspace_size=1 << 30,
max_batch_size=1,
min_subgraph_size=5,
precision_mode=PrecisionType.Float32,
use_static=False,
use_calib_mode=False)
if args.tuned_dynamic_shape:
config.enable_tuned_tensorrt_dynamic_shape(shape_file, True)
else:
# If not specific mkldnn, you can set the blas thread.
# The thread num should not be greater than the number of cores in the CPU.
config.set_cpu_math_library_num_threads(4)
config.enable_mkldnn()

predictor = create_predictor(config)
return predictor


def run(predictor, img):
# copy img data to input tensor
input_names = predictor.get_input_names()
for i, name in enumerate(input_names):
input_tensor = predictor.get_input_handle(name)
input_tensor.reshape(img[i].shape)
input_tensor.copy_from_cpu(img[i].copy())

# do the inference
predictor.run()

results = []
# get out data from output tensor
output_names = predictor.get_output_names()
for i, name in enumerate(output_names):
output_tensor = predictor.get_output_handle(name)
output_data = output_tensor.copy_to_cpu()
results.append(output_data)

return results


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_file",
type=str,
default="",
help="Model filename, Specify this when your model is a combined model."
)
parser.add_argument(
"--params_file",
type=str,
default="",
help="Parameter filename, Specify this when your model is a combined model."
)
parser.add_argument(
"--model_dir",
type=str,
default="",
help="Model dir, If you load a non-combined model, specify the directory of the model."
)
parser.add_argument(
"--use_gpu", type=int, default=0, help="Whether use gpu.")
parser.add_argument(
"--use_trt", type=int, default=0, help="Whether use trt.")
parser.add_argument(
"--tune",
type=int,
default=0,
help="Whether use tune to get shape range.")
parser.add_argument(
"--tuned_dynamic_shape",
type=int,
default=0,
help="Whether use tuned dynamic shape.")
return parser.parse_args()


if __name__ == '__main__':
args = parse_args()
pred = init_predictor(args)
for batch in [1, 2, 4]:
input = np.ones((batch, 3, 224, 224)).astype(np.float32)
result = run(pred, [input])
print("class index: ", np.argmax(result[0][0]))

0 comments on commit eb95780

Please sign in to comment.