Skip to content

Commit

Permalink
[Fix] Fix visualization bug in 3d pose (#2594)
Browse files Browse the repository at this point in the history
  • Loading branch information
LareinaM committed Aug 8, 2023
1 parent cb48094 commit ec30ee1
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
15 changes: 13 additions & 2 deletions demo/body3d_pose_lifter_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def parse_args():
'--save-predictions',
action='store_true',
default=False,
help='whether to save predicted results')
help='Whether to save predicted results')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
Expand Down Expand Up @@ -124,7 +124,14 @@ def parse_args():
'--use-multi-frames',
action='store_true',
default=False,
help='whether to use multi frames for inference in the 2D pose'
help='Whether to use multi frames for inference in the 2D pose'
'detection stage. Default: False.')
parser.add_argument(
'--online',
action='store_true',
default=False,
help='Inference mode. If set to True, can not use future frame'
'information when using multi frames for inference in the 2D pose'
'detection stage. Default: False.')

args = parser.parse_args()
Expand Down Expand Up @@ -405,6 +412,10 @@ def main():
'Only "PoseLifter" model is supported for the 2nd stage ' \
'(2D-to-3D lifting)'

if args.use_multi_frames:
assert 'frame_indices_test' in pose_estimator.cfg.data.test.data_cfg
indices = pose_estimator.cfg.data.test.data_cfg['frame_indices_test']

pose_lifter.cfg.visualizer.radius = args.radius
pose_lifter.cfg.visualizer.line_width = args.thickness
pose_lifter.cfg.visualizer.det_kpt_color = det_kpt_color
Expand Down
30 changes: 17 additions & 13 deletions mmpose/apis/inference_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,11 @@ def collate_pose_sequence(pose_results_2d,
pose_sequences = []
for idx in range(N):
pose_seq = PoseDataSample()
gt_instances = InstanceData()
pred_instances = InstanceData()

for k in pose_results_2d[target_frame][idx].gt_instances.keys():
gt_instances.set_field(
pose_results_2d[target_frame][idx].gt_instances[k], k)
for k in pose_results_2d[target_frame][idx].pred_instances.keys():
if k != 'keypoints':
pred_instances.set_field(
pose_results_2d[target_frame][idx].pred_instances[k], k)
gt_instances = pose_results_2d[target_frame][idx].gt_instances.clone()
pred_instances = pose_results_2d[target_frame][
idx].pred_instances.clone()
pose_seq.pred_instances = pred_instances
pose_seq.gt_instances = gt_instances

Expand Down Expand Up @@ -228,7 +223,7 @@ def collate_pose_sequence(pose_results_2d,
# replicate the right most frame
keypoints[:, frame_idx + 1:] = keypoints[:, frame_idx]
break
pose_seq.pred_instances.keypoints = keypoints
pose_seq.pred_instances.set_field(keypoints, 'keypoints')
pose_sequences.append(pose_seq)

return pose_sequences
Expand Down Expand Up @@ -276,8 +271,15 @@ def inference_pose_lifter_model(model,
bbox_center = None
bbox_scale = None

pose_results_2d_copy = []
for i, pose_res in enumerate(pose_results_2d):
pose_res_copy = []
for j, data_sample in enumerate(pose_res):
data_sample_copy = PoseDataSample()
data_sample_copy.gt_instances = data_sample.gt_instances.clone()
data_sample_copy.pred_instances = data_sample.pred_instances.clone(
)
data_sample_copy.track_id = data_sample.track_id
kpts = data_sample.pred_instances.keypoints
bboxes = data_sample.pred_instances.bboxes
keypoints = []
Expand All @@ -292,11 +294,13 @@ def inference_pose_lifter_model(model,
bbox_scale + bbox_center)
else:
keypoints.append(kpt[:, :2])
pose_results_2d[i][j].pred_instances.keypoints = np.array(
keypoints)
data_sample_copy.pred_instances.set_field(
np.array(keypoints), 'keypoints')
pose_res_copy.append(data_sample_copy)
pose_results_2d_copy.append(pose_res_copy)

pose_sequences_2d = collate_pose_sequence(pose_results_2d, with_track_id,
target_idx)
pose_sequences_2d = collate_pose_sequence(pose_results_2d_copy,
with_track_id, target_idx)

if not pose_sequences_2d:
return []
Expand Down

0 comments on commit ec30ee1

Please sign in to comment.