Skip to content

Commit

Permalink
--feat=update dataset and add ut
Browse files Browse the repository at this point in the history
  • Loading branch information
xiexinch committed Dec 26, 2023
1 parent a79ab3b commit 0a10f08
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 153 deletions.
271 changes: 118 additions & 153 deletions mmpose/datasets/datasets/wholebody3d/h3wb_dataset.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import logging
import os.path as osp
from typing import List, Tuple

import numpy as np
from mmengine.fileio import get_local_path
from mmengine.logging import print_log

from mmpose.registry import DATASETS
from ..body3d import Human36mDataset


@DATASETS.register_module()
class H36MWholeBodyDataset(Human36mDataset):
"""H36MWholeBodyDataset dataset for pose estimation.
METAINFO: dict = dict(from_file='configs/_base_/datasets/h3wb.py')
"""Human3.6M 3D WholeBody Dataset.
"H3WB: Human3.6M 3D WholeBody Dataset and Benchmark", ICCV'2023.
More details can be found in the `paper
<https://arxiv.org/abs/2211.15692>`__.
H36M-WholeBody keypoints::
Expand All @@ -27,7 +28,6 @@ class H36MWholeBodyDataset(Human36mDataset):
Args:
ann_file (str): Annotation file path. Default: ''.
joint_2d_src (str): Annotation file for 2D keypoint.
seq_len (int): Number of frames in a sequence. Default: 1.
seq_step (int): The interval for extracting frames from the video.
Default: 1.
Expand All @@ -46,9 +46,11 @@ class H36MWholeBodyDataset(Human36mDataset):
should be one of the following options:
- ``'gt'``: load from the annotation file
- ``'detection'``: load from a detection result file of 2D keypoint
- ``'pipeline'``: the information will be generated by the pipeline
Default: ``'gt'``.
- ``'detection'``: load from a detection
result file of 2D keypoint
- 'pipeline': the information will be generated by the pipeline
Default: ``'gt'``.
keypoint_2d_det_file (str, optional): The 2D keypoint detection file.
If set, 2d keypoint loaded from this file will be used instead of
ground-truth keypoints. This setting is only when
Expand Down Expand Up @@ -88,161 +90,124 @@ class H36MWholeBodyDataset(Human36mDataset):
image. Default: 1000.
"""

METAINFO: dict = dict(from_file='configs/_base_/datasets/h3wb.py')
def __init__(self, test_mode: bool = False, **kwargs):

def __init__(self, ann_file: str, data_root: str, joint_2d_src: str,
data_prefix: dict, **kwargs):
self.ann_file = ann_file
self.data_root = data_root
self.data_prefix = data_prefix
self.joint_2d_src = joint_2d_src
self.camera_order_id = ['54138969', '55011271', '58860488', '60457274']
if not test_mode:
self.subjects = ['S1', 'S5', 'S6']
else:
self.subjects = ['S7']

super().__init__(
ann_file=ann_file,
data_root=data_root,
data_prefix=data_prefix,
**kwargs)
super().__init__(test_mode=test_mode, **kwargs)

def _load_ann_file(self, ann_file: str) -> dict:
"""Load annotation file to get image information.
Args:
ann_file (str): Annotation file path.
Returns:
dict: Annotation information.
"""

with get_local_path(ann_file) as local_path:
self.ann_data = json.load(open(local_path))
with get_local_path(osp.join(self.data_root,
self.joint_2d_src)) as local_path:
self.joint_2d_ann = json.load(open(local_path))
self._process_image_names(self.ann_data)

def _process_image_names(self, ann_data: dict) -> List[str]:
"""Process image names."""
image_folder = self.data_prefix['img']
img_names = [ann_data[i]['image_path'] for i in ann_data]
image_paths = []
for image_name in img_names:
scene, _, sub, frame = image_name.split('/')
frame, suffix = frame.split('.')
frame_id = f'{int(frame.split("_")[-1]) + 1:06d}'
sub = '_'.join(sub.split(' '))
path = f'{scene}/{scene}_{sub}/{scene}_{sub}_{frame_id}.{suffix}'
if not osp.exists(osp.join(self.data_root, image_folder, path)):
print_log(
f'Failed to read image {path}.',
logger='current',
level=logging.WARN)
continue
image_paths.append(path)
self.image_names = image_paths
data = np.load(local_path, allow_pickle=True)

self.ann_data = data['train_data'].item()
self.camera_data = data['metadata'].item()

def get_sequence_indices(self) -> List[List[int]]:
self.ann_data['imgname'] = self.image_names
return super().get_sequence_indices()
return []

def _load_annotations(self) -> Tuple[List[dict], List[dict]]:
num_keypoints = self.metainfo['num_keypoints']

img_names = np.array(self.image_names)
num_imgs = len(img_names)

scales = np.zeros(num_imgs, dtype=np.float32)
centers = np.zeros((num_imgs, 2), dtype=np.float32)

kpts_3d, kpts_2d = [], []
for k in self.ann_data.keys():
if not isinstance(self.ann_data[k], dict):
continue
ann, ann_2d = self.ann_data[k], self.joint_2d_ann[k]
kpts_2d_i, kpts_3d_i = self._get_kpts(ann, ann_2d)
kpts_3d.append(kpts_3d_i)
kpts_2d.append(kpts_2d_i)

kpts_3d = np.concatenate(kpts_3d, axis=0)
kpts_2d = np.concatenate(kpts_2d, axis=0)
kpts_visible = np.ones_like(kpts_2d[..., 0], dtype=np.float32)

# Normalize 3D keypoints like H36M
# Ref: https://github.com/open-mmlab/mmpose/blob/main/tools/dataset_converters/preprocess_h36m.py#L324 # noqa
kpts_3d /= 1000.0

if self.factor_file:
with get_local_path(self.factor_file) as local_path:
factors = np.load(local_path).astype(np.float32)
else:
factors = np.zeros((kpts_3d.shape[0], ), dtype=np.float32)

instance_list = []
for idx, frame_ids in enumerate(self.sequence_indices):
expected_num_frames = self.seq_len
if self.multiple_target:
expected_num_frames = self.multiple_target

assert len(frame_ids) == (expected_num_frames), (
f'Expected `frame_ids` == {expected_num_frames}, but '
f'got {len(frame_ids)} ')

_img_names = img_names[frame_ids]
_kpts_2d = kpts_2d[frame_ids]
_kpts_3d = kpts_3d[frame_ids]
_kpts_visible = kpts_visible[frame_ids]
factor = factors[frame_ids].astype(np.float32)

target_idx = [-1] if self.causal else [int(self.seq_len) // 2]
if self.multiple_target > 0:
target_idx = list(range(self.multiple_target))

instance_info = {
'num_keypoints': num_keypoints,
'keypoints': _kpts_2d,
'keypoints_3d': _kpts_3d,
'keypoints_visible': _kpts_visible,
'keypoints_3d_visible': _kpts_visible,
'scale': scales[idx],
'center': centers[idx].astype(np.float32).reshape(1, -1),
'factor': factor,
'id': idx,
'category_id': 1,
'iscrowd': 0,
'img_paths': list(_img_names),
'img_ids': frame_ids,
'lifting_target': _kpts_3d[target_idx],
'lifting_target_visible': _kpts_visible[target_idx],
'target_img_path': _img_names[target_idx],
}

if self.camera_param_file:
_cam_param = self.get_camera_param(_img_names[0])
else:
# Use the max value of camera parameters in Human3.6M dataset
_cam_param = {
'w': 1000,
'h': 1002,
'f': np.array([[1149.67569987], [1148.79896857]]),
'c': np.array([[519.81583718], [515.45148698]])
}
instance_info['camera_param'] = _cam_param
instance_list.append(instance_info)

image_list = []
if self.data_mode == 'bottomup':
for idx, img_name in enumerate(img_names):
img_info = self.get_img_info(idx, img_name)
image_list.append(img_info)

return instance_list, image_list
instance_id = 0
for subject in self.subjects:
actions = self.ann_data[subject].keys()
for act in actions:
for cam in self.camera_order_id:
if cam not in self.ann_data[subject][act]:
continue
keypoints_2d = self.ann_data[subject][act][cam]['pose_2d']
keypoints_3d = self.ann_data[subject][act][cam][
'camera_3d']
num_keypoints = keypoints_2d.shape[1]

camera_param = self.camera_data[subject][cam]
camera_param = {
'K': camera_param['K'][0, :2, ...],
'R': camera_param['R'][0],
'T': camera_param['T'].reshape(3, 1),
'Distortion': camera_param['Distortion'][0]
}

seq_step = 1
_len = (self.seq_len - 1) * seq_step + 1
_indices = list(
range(len(self.ann_data[subject][act]['frame_id'])))
seq_indices = [
_indices[i:(i + _len):seq_step]
for i in list(range(0,
len(_indices) - _len + 1))
]

for idx, frame_ids in enumerate(seq_indices):
expected_num_frames = self.seq_len
if self.multiple_target:
expected_num_frames = self.multiple_target

assert len(frame_ids) == (expected_num_frames), (
f'Expected `frame_ids` == {expected_num_frames}, but ' # noqa
f'got {len(frame_ids)} ')

_kpts_2d = keypoints_2d[frame_ids]
_kpts_3d = keypoints_3d[frame_ids]

target_idx = [-1] if self.causal else [
int(self.seq_len) // 2
]
if self.multiple_target > 0:
target_idx = list(range(self.multiple_target))

instance_info = {
'num_keypoints':
num_keypoints,
'keypoints':
_kpts_2d,
'keypoints_3d':
_kpts_3d / 1000,
'keypoints_visible':
np.ones_like(_kpts_2d[..., 0], dtype=np.float32),
'keypoints_3d_visible':
np.ones_like(_kpts_2d[..., 0], dtype=np.float32),
'scale':
np.zeros((1, 1), dtype=np.float32),
'center':
np.zeros((1, 2), dtype=np.float32),
'factor':
np.zeros((1, 1), dtype=np.float32),
'id':
instance_id,
'category_id':
1,
'iscrowd':
0,
'camera_param':
camera_param,
'img_paths': [
f'{subject}/{act}/{cam}/{i:06d}.jpg'
for i in frame_ids
],
'img_ids':
frame_ids,
'lifting_target':
_kpts_3d[target_idx] / 1000,
'lifting_target_visible':
np.ones_like(_kpts_2d[..., 0],
dtype=np.float32)[target_idx],
}
instance_list.append(instance_info)

if self.data_mode == 'bottomup':
for idx, img_name in enumerate(
instance_info['img_paths']):
img_info = self.get_img_info(idx, img_name)
image_list.append(img_info)

instance_id += 1

def _get_kpts(self, ann: dict,
ann_2d: dict) -> Tuple[np.ndarray, np.ndarray]:
"""Get 2D keypoints and 3D keypoints from annotation."""
kpts = ann['keypoints_3d']
kpts_2d = ann_2d['keypoints_2d']
kpts_3d = np.array([[j['x'], j['y'], j['z']] for _, j in kpts.items()],
dtype=np.float32)[np.newaxis, ...]
kpts_2d = np.array([[j['x'], j['y']] for _, j in kpts_2d.items()],
dtype=np.float32)[np.newaxis, ...]
return kpts_2d, kpts_3d
return instance_list, image_list
Binary file added tests/data/h3wb/h3wb_train_sub.npz
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np

from mmpose.datasets.datasets.wholebody3d import H36MWholeBodyDataset


class TestH36MWholeBodyDataset(TestCase):

def build_h3wb_dataset(self, **kwargs):

cfg = dict(
ann_file='h3wb_train_sub.npz',
data_mode='topdown',
data_root='tests/data/h3wb',
pipeline=[])

cfg.update(kwargs)
return H36MWholeBodyDataset(**cfg)

def check_data_info_keys(self, data_info: dict):
expected_keys = dict(
img_paths=list,
keypoints=np.ndarray,
keypoints_3d=np.ndarray,
scale=np.ndarray,
center=np.ndarray,
id=int)

for key, type_ in expected_keys.items():
self.assertIn(key, data_info)
self.assertIsInstance(data_info[key], type_, key)

def test_metainfo(self):
dataset = self.build_h3wb_dataset()
# test dataset_name
self.assertEqual(dataset.metainfo['dataset_name'], 'h3wb')

# test number of keypoints
num_keypoints = 133
self.assertEqual(dataset.metainfo['num_keypoints'], num_keypoints)
self.assertEqual(
len(dataset.metainfo['keypoint_colors']), num_keypoints)
self.assertEqual(
len(dataset.metainfo['dataset_keypoint_weights']), num_keypoints)

# test some extra metainfo
self.assertEqual(
len(dataset.metainfo['skeleton_links']),
len(dataset.metainfo['skeleton_link_colors']))

def test_topdown(self):
# test topdown training
dataset = self.build_h3wb_dataset(data_mode='topdown')
dataset.full_init()
self.assertEqual(len(dataset), 3)
self.check_data_info_keys(dataset[0])

# test topdown testing
dataset = self.build_h3wb_dataset(data_mode='topdown', test_mode=True)
dataset.full_init()
self.assertEqual(len(dataset), 1)
self.check_data_info_keys(dataset[0])

# test topdown training with sequence config
dataset = self.build_h3wb_dataset(
data_mode='topdown',
seq_len=1,
seq_step=1,
causal=False,
pad_video_seq=True)
dataset.full_init()
self.assertEqual(len(dataset), 3)
self.check_data_info_keys(dataset[0])

0 comments on commit 0a10f08

Please sign in to comment.