From 8b7c336e18d58ca0e3c6b47c994c3e007059945c Mon Sep 17 00:00:00 2001 From: xiexinch Date: Wed, 20 Dec 2023 16:38:30 +0800 Subject: [PATCH] --update=add param descriptoin --- mmpose/codecs/image_pose_lifting.py | 2 +- mmpose/codecs/video_pose_lifting.py | 2 +- .../datasets/wholebody3d/h3wb_dataset.py | 19 ++++++------------- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/mmpose/codecs/image_pose_lifting.py b/mmpose/codecs/image_pose_lifting.py index 70df5f45a2..1665d88e1d 100644 --- a/mmpose/codecs/image_pose_lifting.py +++ b/mmpose/codecs/image_pose_lifting.py @@ -272,7 +272,7 @@ def decode(self, if target_root is not None and target_root.size > 0: keypoints = keypoints + target_root - if self.remove_root: + if self.remove_root and len(self.root_index) == 1: keypoints = np.insert( keypoints, self.root_index, target_root, axis=1) scores = np.ones(keypoints.shape[:-1], dtype=np.float32) diff --git a/mmpose/codecs/video_pose_lifting.py b/mmpose/codecs/video_pose_lifting.py index 8e3d15ca1c..5a5a7b1983 100644 --- a/mmpose/codecs/video_pose_lifting.py +++ b/mmpose/codecs/video_pose_lifting.py @@ -238,7 +238,7 @@ def decode(self, if target_root is not None and target_root.size > 0: keypoints = keypoints + target_root - if self.remove_root: + if self.remove_root and len(self.root_index) == 1: keypoints = np.insert( keypoints, self.root_index, target_root, axis=1) scores = np.ones(keypoints.shape[:-1], dtype=np.float32) diff --git a/mmpose/datasets/datasets/wholebody3d/h3wb_dataset.py b/mmpose/datasets/datasets/wholebody3d/h3wb_dataset.py index c7ecd4be55..424fbb08ac 100644 --- a/mmpose/datasets/datasets/wholebody3d/h3wb_dataset.py +++ b/mmpose/datasets/datasets/wholebody3d/h3wb_dataset.py @@ -27,6 +27,7 @@ class H36MWholeBodyDataset(Human36mDataset): Args: ann_file (str): Annotation file path. Default: ''. + joint_2d_src (str): Annotation file for 2D keypoint. seq_len (int): Number of frames in a sequence. Default: 1. seq_step (int): The interval for extracting frames from the video. Default: 1. @@ -45,11 +46,9 @@ class H36MWholeBodyDataset(Human36mDataset): should be one of the following options: - ``'gt'``: load from the annotation file - - ``'detection'``: load from a detection - result file of 2D keypoint - - 'pipeline': the information will be generated by the pipeline - - Default: ``'gt'``. + - ``'detection'``: load from a detection result file of 2D keypoint + - ``'pipeline'``: the information will be generated by the pipeline + Default: ``'gt'``. keypoint_2d_det_file (str, optional): The 2D keypoint detection file. If set, 2d keypoint loaded from this file will be used instead of ground-truth keypoints. This setting is only when @@ -91,18 +90,12 @@ class H36MWholeBodyDataset(Human36mDataset): METAINFO: dict = dict(from_file='configs/_base_/datasets/h3wb.py') - def __init__(self, - ann_file: str, - data_root: str, - data_prefix: dict, - joint_2d_src: str, - normalize_with_dataset_stats: bool = False, - **kwargs): + def __init__(self, ann_file: str, data_root: str, joint_2d_src: str, + data_prefix: dict, **kwargs): self.ann_file = ann_file self.data_root = data_root self.data_prefix = data_prefix self.joint_2d_src = joint_2d_src - self.normalize_with_dataset_stats = normalize_with_dataset_stats super().__init__( ann_file=ann_file,