Skip to content

Commit

Permalink
--update=add param descriptoin
Browse files Browse the repository at this point in the history
  • Loading branch information
xiexinch committed Dec 20, 2023
1 parent f9896df commit 8b7c336
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 15 deletions.
2 changes: 1 addition & 1 deletion mmpose/codecs/image_pose_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def decode(self,

if target_root is not None and target_root.size > 0:
keypoints = keypoints + target_root
if self.remove_root:
if self.remove_root and len(self.root_index) == 1:
keypoints = np.insert(
keypoints, self.root_index, target_root, axis=1)
scores = np.ones(keypoints.shape[:-1], dtype=np.float32)
Expand Down
2 changes: 1 addition & 1 deletion mmpose/codecs/video_pose_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def decode(self,

if target_root is not None and target_root.size > 0:
keypoints = keypoints + target_root
if self.remove_root:
if self.remove_root and len(self.root_index) == 1:
keypoints = np.insert(
keypoints, self.root_index, target_root, axis=1)
scores = np.ones(keypoints.shape[:-1], dtype=np.float32)
Expand Down
19 changes: 6 additions & 13 deletions mmpose/datasets/datasets/wholebody3d/h3wb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class H36MWholeBodyDataset(Human36mDataset):
Args:
ann_file (str): Annotation file path. Default: ''.
joint_2d_src (str): Annotation file for 2D keypoint.
seq_len (int): Number of frames in a sequence. Default: 1.
seq_step (int): The interval for extracting frames from the video.
Default: 1.
Expand All @@ -45,11 +46,9 @@ class H36MWholeBodyDataset(Human36mDataset):
should be one of the following options:
- ``'gt'``: load from the annotation file
- ``'detection'``: load from a detection
result file of 2D keypoint
- 'pipeline': the information will be generated by the pipeline
Default: ``'gt'``.
- ``'detection'``: load from a detection result file of 2D keypoint
- ``'pipeline'``: the information will be generated by the pipeline
Default: ``'gt'``.
keypoint_2d_det_file (str, optional): The 2D keypoint detection file.
If set, 2d keypoint loaded from this file will be used instead of
ground-truth keypoints. This setting is only when
Expand Down Expand Up @@ -91,18 +90,12 @@ class H36MWholeBodyDataset(Human36mDataset):

METAINFO: dict = dict(from_file='configs/_base_/datasets/h3wb.py')

def __init__(self,
ann_file: str,
data_root: str,
data_prefix: dict,
joint_2d_src: str,
normalize_with_dataset_stats: bool = False,
**kwargs):
def __init__(self, ann_file: str, data_root: str, joint_2d_src: str,
data_prefix: dict, **kwargs):
self.ann_file = ann_file
self.data_root = data_root
self.data_prefix = data_prefix
self.joint_2d_src = joint_2d_src
self.normalize_with_dataset_stats = normalize_with_dataset_stats

super().__init__(
ann_file=ann_file,
Expand Down

0 comments on commit 8b7c336

Please sign in to comment.