-
Notifications
You must be signed in to change notification settings - Fork 22
/
demo_video.py
114 lines (94 loc) · 3.62 KB
/
demo_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import cv2
import numpy as np
import argparse
import time
import torch
from KittiCalibration import KittiCalibration
from KittiVideo import KittiVideo
from visualizer import Visualizer
from BiSeNetv2.model.BiseNetv2 import BiSeNetV2
from BiSeNetv2.utils.utils import preprocessing_kitti, postprocessing
from pointpainting import PointPainter
# Select the compute device once at import time: prefer CUDA when a GPU
# is available, otherwise fall back to CPU. Both the string form (`dev`,
# used for torch.load's map_location) and the torch.device object
# (`device`, used for .to()) are kept, as downstream code needs each.
if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"
device = torch.device(dev)
def main(args):
# Semantic Segmentation
bisenetv2 = BiSeNetV2()
checkpoint = torch.load(args.weights_path, map_location=dev)
bisenetv2.load_state_dict(checkpoint['bisenetv2'], strict=False)
bisenetv2.eval()
bisenetv2.to(device)
# Fusion
painter = PointPainter()
video = KittiVideo(
video_root=args.video_path,
calib_root=args.calib_path
)
visualizer = Visualizer(args.mode)
frames = []
if args.mode == '3d':
frame_shape = (1280, 720)
else:
frame_shape = (750, 900)
avg_time = 0
for i in range(len(video)):
t1 = time.time()
image, pointcloud, calib = video[i]
# print(image.shape, pointcloud.shape)
input_image = preprocessing_kitti(image)
semantic = bisenetv2(input_image)
semantic = postprocessing(semantic)
painted_pointcloud = painter.paint(pointcloud, semantic, calib)
if args.mode == '3d':
screenshot = visualizer.visuallize_pointcloud(painted_pointcloud, blocking=False)
print(screenshot.shape)
frames.append(screenshot)
else:
color_image = visualizer.get_colored_image(image, semantic)
if args.mode == 'img':
frames.append(color_image)
cv2.imshow('color_image', color_image)
elif args.mode == '2d':
scene_2D = visualizer.get_scene_2D(color_image, painted_pointcloud, calib)
frames.append(scene_2D)
# cv2.imshow('scene', scene_2D)
# if cv2.waitKey(0) == 27:
# cv2.destroyAllWindows()
# break
# if i == 20:
# break
avg_time += (time.time()-t1)
print(f'{i} sample')