-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
118 lines (101 loc) · 5.07 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import logging
import os
import sys
import importlib
import argparse
import munch
import yaml
from utils.vis_utils import plot_single_pcd
from utils.train_utils import *
from dataset import ShapeNetH5
def test():
    """Evaluate a trained completion model on the ShapeNetH5 test split.

    Reads its configuration from the module-level ``args`` object (``gpu``,
    ``num_points``, ``batch_size``, ``workers``, ``model_name``,
    ``load_model``, ``save_vis``, ``step_interval_to_print``) and, when
    ``args.save_vis`` is set, writes pictures under the module-level
    ``log_dir``. Both globals are assigned by the ``__main__`` section.

    Logs the summed-per-category metrics divided by hard-coded category
    sizes, followed by the running averages of every metric.
    """
    # Set GPU to use
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    dataset_test = ShapeNetH5(train=False, npoints=args.num_points)
    dataloader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
    )
    dataset_length = len(dataset_test)
    # The loader knows the true batch count, including a short final batch.
    num_batches = len(dataloader_test)
    logging.info('Length of test dataset:%d', dataset_length)

    # load model
    model_module = importlib.import_module('.%s' % args.model_name, 'models')
    net = torch.nn.DataParallel(model_module.Model(args))
    net.cuda()
    net.module.load_state_dict(torch.load(args.load_model)['net_state_dict'])
    logging.info("%s's previous weights loaded.", args.model_name)
    net.eval()

    metrics = ['cd_p', 'cd_t', 'emd', 'f1']
    test_loss_meters = {m: AverageValueMeter() for m in metrics}

    cat_name = ['airplane', 'cabinet', 'car', 'chair', 'lamp', 'sofa', 'table', 'watercraft',
                'bed', 'bench', 'bookshelf', 'bus', 'guitar', 'motorbike', 'pistol', 'skateboard']
    # One row per category, one column per metric; sums are accumulated here.
    test_loss_cat = torch.zeros([len(cat_name), len(metrics)], dtype=torch.float32).cuda()
    # Hard-coded divisors: 150 * 26 for the first 8 categories and 50 * 26 for
    # the last 8 ("novel") ones.
    # NOTE(review): presumably these match the sample counts of the test
    # split - confirm against the dataset before trusting per-category values.
    cat_num = torch.ones([8, 1], dtype=torch.float32).cuda() * 150 * 26
    novel_cat_num = torch.ones([8, 1], dtype=torch.float32).cuda() * 50 * 26
    cat_num = torch.cat((cat_num, novel_cat_num), dim=0)

    # A range supports O(1) membership tests for ints, unlike a list.
    idx_to_plot = range(0, 1600, 75)

    logging.info('Testing...')
    if args.save_vis:
        save_gt_path = os.path.join(log_dir, 'pics', 'gt')
        save_partial_path = os.path.join(log_dir, 'pics', 'partial')
        save_completion_path = os.path.join(log_dir, 'pics', 'completion')
        os.makedirs(save_gt_path, exist_ok=True)
        os.makedirs(save_partial_path, exist_ok=True)
        os.makedirs(save_completion_path, exist_ok=True)

    with torch.no_grad():
        for i, data in enumerate(dataloader_test):
            inputs_cpu = data['partial_cloud']
            gt_cpu = data['gtcloud']
            label = data['label']

            inputs = inputs_cpu.float().cuda()
            gt = gt_cpu.float().cuda()
            inputs = inputs.transpose(2, 1).contiguous()
            result_dict = net(inputs, gt, is_training=False)

            for name, meter in test_loss_meters.items():
                meter.update(result_dict[name].mean().item())

            # Accumulate every sample's metrics into its category row.
            for j, cat_label in enumerate(label):
                cat_idx = int(cat_label)
                for ind, m in enumerate(metrics):
                    test_loss_cat[cat_idx, ind] += result_dict[m][j]

            if i % args.step_interval_to_print == 0:
                logging.info('test [%d/%d]', i, num_batches)

            if args.save_vis:
                # Use the real size of this batch: the final batch can be
                # shorter than args.batch_size, and indexing past its end
                # would raise IndexError.
                for j in range(inputs_cpu.size(0)):
                    # Global sample index; valid because shuffle=False and
                    # only the last batch can be short.
                    idx = i * args.batch_size + j
                    if idx not in idx_to_plot:
                        continue
                    pic = 'object_%d.png' % idx
                    plot_single_pcd(result_dict['out2'][j].cpu().numpy(),
                                    os.path.join(save_completion_path, pic))
                    # Converted to NumPy like the other two calls.
                    plot_single_pcd(gt_cpu[j].cpu().numpy(),
                                    os.path.join(save_gt_path, pic))
                    plot_single_pcd(inputs_cpu[j].cpu().numpy(),
                                    os.path.join(save_partial_path, pic))

        logging.info('Loss per category:')
        category_parts = []
        for i, name in enumerate(cat_name):
            category_parts.append('\ncategory name: %s' % name)
            for ind, m in enumerate(metrics):
                # Every metric except f1 is reported multiplied by 1e4.
                scale_factor = 1 if m == 'f1' else 10000
                value = (test_loss_cat[i, ind] / cat_num[i, 0] * scale_factor).item()
                category_parts.append(' %s: %f' % (m, value))
        logging.info(''.join(category_parts))

        logging.info('Overview results:')
        overview_log = ''.join(
            '%s: %f ' % (metric, meter.avg) for metric, meter in test_loss_meters.items()
        )
        logging.info(overview_log)
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the YAML config into the
    # module-level `args`, derive `log_dir` from the checkpoint path, set up
    # logging to both <log_dir>/test.log and stdout, then run test().
    parser = argparse.ArgumentParser(description='Test config file')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device to use', default='0,1', type=str)
    parser.add_argument('-c', '--config', help='path to config file', required=True)
    parse_args = parser.parse_args()

    config_path = parse_args.config
    # Close the config file deterministically. An empty YAML document loads
    # as None; fall back to an empty mapping so the load_model check below
    # reports the problem instead of an AttributeError on None.
    with open(config_path) as config_file:
        args = munch.munchify(yaml.safe_load(config_file) or {})

    # NOTE: --gpu defaults to '0,1', so this branch always runs and the
    # command-line value (or its default) replaces any `gpu` in the config.
    if parse_args.gpu_id is not None:
        args.gpu = parse_args.gpu_id

    # .get() so a config without the key raises the ValueError below rather
    # than an AttributeError from the attribute lookup.
    if not args.get('load_model'):
        raise ValueError('Model path must be provided to load model!')

    exp_name = os.path.basename(args.load_model)
    # dirname() is '' for a bare file name; os.makedirs('') would fail, so
    # use the current directory in that case.
    log_dir = os.path.dirname(args.load_model) or '.'
    os.makedirs(log_dir, exist_ok=True)
    logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(os.path.join(log_dir, 'test.log')),
                                                      logging.StreamHandler(sys.stdout)])
    test()