forked from koyeongmin/PINet_new
-
Notifications
You must be signed in to change notification settings - Fork 0
/
onnx_converter.py
44 lines (32 loc) · 1.21 KB
/
onnx_converter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
'''
Convert a trained PINet lane-detection model to ONNX, or validate a previously exported ONNX file.
'''
import os

import torch
import torch.onnx

from hourglass_network import lane_detection_network
# Mode switch:
#   True  -> load the trained PyTorch weights and export the network to ONNX.
#   False -> load the previously exported ONNX file, validate it and print
#            its graph in human-readable form.
convert = True
save_dir = '/media/data4/yg/PINet_new-master/CurveLanes/onnx_models/'
weights_path = '/media/data4/yg/PINet_new-master/CurveLanes/savefile/32_tensor(1.1001)_lane_detection_network.pkl'
onnx_path = os.path.join(save_dir, "pinet_v2.onnx")


def load_trained_model(weights_path, model=None):
    """Load trained weights into the network and switch it to eval mode.

    Args:
        weights_path: Path to a file written by ``torch.save`` that holds the
            network's ``state_dict``.
        model: Optional pre-built module to load the weights into. Defaults
            to a fresh ``lane_detection_network()``.

    Returns:
        The module with the weights loaded, in evaluation mode.
    """
    if model is None:
        model = lane_detection_network()
    # map_location="cpu" lets a checkpoint that was saved from a GPU run be
    # deserialized on a CPU-only machine; load_state_dict then copies the
    # tensors into the model's own parameters.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Trace the inference behaviour: any dropout / batch-norm layers must be
    # in eval mode when the graph is exported.
    model.eval()
    return model


def export_to_onnx(weights_path, onnx_path, *, batch_size=1, channels=3,
                   height=256, width=512, model=None, verbose=True):
    """Export the trained network to an ONNX file.

    The exporter traces the model with a random dummy input of shape
    ``(batch_size, channels, height, width)``. No ``dynamic_axes`` are passed,
    so that shape is fixed in the exported graph.

    Args:
        weights_path: Path to the saved ``state_dict``.
        onnx_path: Destination path of the ``.onnx`` file.
        batch_size, channels, height, width: Shape of the traced input.
        model: Optional pre-built module (see ``load_trained_model``).
        verbose: Forwarded to ``torch.onnx.export``; prints the exported graph.
    """
    model = load_trained_model(weights_path, model)
    dummy_input = torch.randn(batch_size, channels, height, width)
    torch.onnx.export(model, dummy_input, onnx_path, verbose=verbose)


def check_onnx_model(onnx_path):
    """Validate an exported ONNX file and print its graph.

    Raises whatever ``onnx.checker.check_model`` raises if the model is
    malformed.
    """
    # Imported lazily: the onnx package is only needed for the check mode.
    import onnx
    model = onnx.load(onnx_path)
    # Check that the IR is well formed.
    onnx.checker.check_model(model)
    print(onnx.helper.printable_graph(model.graph))


if __name__ == "__main__":
    if convert:
        export_to_onnx(weights_path, onnx_path)
    else:
        check_onnx_model(onnx_path)