diff --git a/.gitignore b/.gitignore
index 1b18d31b7f..2b337460f3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -126,6 +126,7 @@ docs/**/modelzoo.md
 !tests/data/**/*.log.json
 !tests/data/**/*.pth
 !tests/data/**/*.npy
+!tests/data/**/vis/
 
 # Pytorch
 *.pth
diff --git a/README.md b/README.md
index 6d0fcb2134..a45e03ec9f 100644
--- a/README.md
+++ b/README.md
@@ -243,6 +243,7 @@ A summary can be found in the [Model Zoo](https://mmpose.readthedocs.io/en/lates
 - [x] [Human3.6M](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/datasets.html#human3-6m-tpami-2014) \[[homepage](http://vision.imar.ro/human3.6m/description.php)\] (TPAMI'2014)
 - [x] [COCO](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/datasets.html#coco-eccv-2014) \[[homepage](http://cocodataset.org/)\] (ECCV'2014)
 - [x] [CMU Panoptic](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/datasets.html#cmu-panoptic-iccv-2015) \[[homepage](http://domedb.perception.cs.cmu.edu/)\] (ICCV'2015)
+- [x] [300VW](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/datasets.html#300vw-iccvw-2015) \[[homepage](https://ibug.doc.ic.ac.uk/resources/300-VW/)\] (ICCVW'2015)
 - [x] [DeepFashion](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/datasets.html#deepfashion-cvpr-2016) \[[homepage](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/LandmarkDetection.html)\] (CVPR'2016)
 - [x] [300W](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/datasets.html#300w-imavis-2016) \[[homepage](https://ibug.doc.ic.ac.uk/resources/300-W/)\] (IMAVIS'2016)
 - [x] [RHD](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/datasets.html#rhd-iccv-2017) \[[homepage](https://lmb.informatik.uni-freiburg.de/resources/datasets/RenderedHandposeDataset.en.html)\] (ICCV'2017)
diff --git a/configs/_base_/datasets/300vw.py b/configs/_base_/datasets/300vw.py
new file mode 100644
index 0000000000..f75d6ec922
--- /dev/null
+++ b/configs/_base_/datasets/300vw.py
@@ -0,0 +1,134 @@
+dataset_info = dict(
+    dataset_name='300vw',
+    paper_info=dict(
+        author='Jie Shen, Stefanos Zafeiriou, Grigorios G. Chrysos, '
+        'Jean Kossaifi, Georgios Tzimiropoulos, Maja Pantic',
+        title='The First Facial Landmark Tracking in-the-Wild Challenge: '
+        'Benchmark and Results',
+        container='Proceedings of the IEEE '
+        'international conference on computer vision workshops',
+        year='2015',
+        homepage='https://ibug.doc.ic.ac.uk/resources/300-VW/',
+    ),
+    keypoint_info={
+        0: dict(name='kpt-0', id=0, color=[255, 0, 0], type='', swap='kpt-16'),
+        1: dict(name='kpt-1', id=1, color=[255, 0, 0], type='', swap='kpt-15'),
+        2: dict(name='kpt-2', id=2, color=[255, 0, 0], type='', swap='kpt-14'),
+        3: dict(name='kpt-3', id=3, color=[255, 0, 0], type='', swap='kpt-13'),
+        4: dict(name='kpt-4', id=4, color=[255, 0, 0], type='', swap='kpt-12'),
+        5: dict(name='kpt-5', id=5, color=[255, 0, 0], type='', swap='kpt-11'),
+        6: dict(name='kpt-6', id=6, color=[255, 0, 0], type='', swap='kpt-10'),
+        7: dict(name='kpt-7', id=7, color=[255, 0, 0], type='', swap='kpt-9'),
+        8: dict(name='kpt-8', id=8, color=[255, 0, 0], type='', swap=''),
+        9: dict(name='kpt-9', id=9, color=[255, 0, 0], type='', swap='kpt-7'),
+        10:
+        dict(name='kpt-10', id=10, color=[255, 0, 0], type='', swap='kpt-6'),
+        11:
+        dict(name='kpt-11', id=11, color=[255, 0, 0], type='', swap='kpt-5'),
+        12:
+        dict(name='kpt-12', id=12, color=[255, 0, 0], type='', swap='kpt-4'),
+        13:
+        dict(name='kpt-13', id=13, color=[255, 0, 0], type='', swap='kpt-3'),
+        14:
+        dict(name='kpt-14', id=14, color=[255, 0, 0], type='', swap='kpt-2'),
+        15:
+        dict(name='kpt-15', id=15, color=[255, 0, 0], type='', swap='kpt-1'),
+        16:
+        dict(name='kpt-16', id=16, color=[255, 0, 0], type='', swap='kpt-0'),
+        17:
+        dict(name='kpt-17', id=17, color=[255, 0, 0], type='', swap='kpt-26'),
+        18:
+        dict(name='kpt-18', id=18, color=[255, 0, 0], type='', swap='kpt-25'),
+        19:
+        dict(name='kpt-19', id=19, color=[255, 0, 0], type='', swap='kpt-24'),
+        20:
+        dict(name='kpt-20', id=20, color=[255, 0, 0], type='', swap='kpt-23'),
+        21:
+        dict(name='kpt-21', id=21, color=[255, 0, 0], type='', swap='kpt-22'),
+        22:
+        dict(name='kpt-22', id=22, color=[255, 0, 0], type='', swap='kpt-21'),
+        23:
+        dict(name='kpt-23', id=23, color=[255, 0, 0], type='', swap='kpt-20'),
+        24:
+        dict(name='kpt-24', id=24, color=[255, 0, 0], type='', swap='kpt-19'),
+        25:
+        dict(name='kpt-25', id=25, color=[255, 0, 0], type='', swap='kpt-18'),
+        26:
+        dict(name='kpt-26', id=26, color=[255, 0, 0], type='', swap='kpt-17'),
+        27: dict(name='kpt-27', id=27, color=[255, 0, 0], type='', swap=''),
+        28: dict(name='kpt-28', id=28, color=[255, 0, 0], type='', swap=''),
+        29: dict(name='kpt-29', id=29, color=[255, 0, 0], type='', swap=''),
+        30: dict(name='kpt-30', id=30, color=[255, 0, 0], type='', swap=''),
+        31:
+        dict(name='kpt-31', id=31, color=[255, 0, 0], type='', swap='kpt-35'),
+        32:
+        dict(name='kpt-32', id=32, color=[255, 0, 0], type='', swap='kpt-34'),
+        33: dict(name='kpt-33', id=33, color=[255, 0, 0], type='', swap=''),
+        34:
+        dict(name='kpt-34', id=34, color=[255, 0, 0], type='', swap='kpt-32'),
+        35:
+        dict(name='kpt-35', id=35, color=[255, 0, 0], type='', swap='kpt-31'),
+        36:
+        dict(name='kpt-36', id=36, color=[255, 0, 0], type='', swap='kpt-45'),
+        37:
+        dict(name='kpt-37', id=37, color=[255, 0, 0], type='', swap='kpt-44'),
+        38:
+        dict(name='kpt-38', id=38, color=[255, 0, 0], type='', swap='kpt-43'),
+        39:
+        dict(name='kpt-39', id=39, color=[255, 0, 0], type='', swap='kpt-42'),
+        40:
+        dict(name='kpt-40', id=40, color=[255, 0, 0], type='', swap='kpt-47'),
+        41: dict(
+            name='kpt-41', id=41, color=[255, 0, 0], type='', 
swap='kpt-46'), + 42: dict( + name='kpt-42', id=42, color=[255, 0, 0], type='', swap='kpt-39'), + 43: dict( + name='kpt-43', id=43, color=[255, 0, 0], type='', swap='kpt-38'), + 44: dict( + name='kpt-44', id=44, color=[255, 0, 0], type='', swap='kpt-37'), + 45: dict( + name='kpt-45', id=45, color=[255, 0, 0], type='', swap='kpt-36'), + 46: dict( + name='kpt-46', id=46, color=[255, 0, 0], type='', swap='kpt-41'), + 47: dict( + name='kpt-47', id=47, color=[255, 0, 0], type='', swap='kpt-40'), + 48: dict( + name='kpt-48', id=48, color=[255, 0, 0], type='', swap='kpt-54'), + 49: dict( + name='kpt-49', id=49, color=[255, 0, 0], type='', swap='kpt-53'), + 50: dict( + name='kpt-50', id=50, color=[255, 0, 0], type='', swap='kpt-52'), + 51: dict(name='kpt-51', id=51, color=[255, 0, 0], type='', swap=''), + 52: dict( + name='kpt-52', id=52, color=[255, 0, 0], type='', swap='kpt-50'), + 53: dict( + name='kpt-53', id=53, color=[255, 0, 0], type='', swap='kpt-49'), + 54: dict( + name='kpt-54', id=54, color=[255, 0, 0], type='', swap='kpt-48'), + 55: dict( + name='kpt-55', id=55, color=[255, 0, 0], type='', swap='kpt-59'), + 56: dict( + name='kpt-56', id=56, color=[255, 0, 0], type='', swap='kpt-58'), + 57: dict(name='kpt-57', id=57, color=[255, 0, 0], type='', swap=''), + 58: dict( + name='kpt-58', id=58, color=[255, 0, 0], type='', swap='kpt-56'), + 59: dict( + name='kpt-59', id=59, color=[255, 0, 0], type='', swap='kpt-55'), + 60: dict( + name='kpt-60', id=60, color=[255, 0, 0], type='', swap='kpt-64'), + 61: dict( + name='kpt-61', id=61, color=[255, 0, 0], type='', swap='kpt-63'), + 62: dict(name='kpt-62', id=62, color=[255, 0, 0], type='', swap=''), + 63: dict( + name='kpt-63', id=63, color=[255, 0, 0], type='', swap='kpt-61'), + 64: dict( + name='kpt-64', id=64, color=[255, 0, 0], type='', swap='kpt-60'), + 65: dict( + name='kpt-65', id=65, color=[255, 0, 0], type='', swap='kpt-67'), + 66: dict(name='kpt-66', id=66, color=[255, 0, 0], type='', swap=''), + 67: dict( + name='kpt-67', id=67, color=[255, 0, 0], type='', swap='kpt-65'), + }, + skeleton_info={}, + joint_weights=[1.] * 68, + sigmas=[]) diff --git a/docs/en/dataset_zoo/2d_face_keypoint.md b/docs/en/dataset_zoo/2d_face_keypoint.md index f0861bd2c3..bc3463db7e 100644 --- a/docs/en/dataset_zoo/2d_face_keypoint.md +++ b/docs/en/dataset_zoo/2d_face_keypoint.md @@ -6,6 +6,7 @@ If your folder structure is different, you may need to change the corresponding MMPose supported datasets: - [300W](#300w-dataset) \[ [Homepage](https://ibug.doc.ic.ac.uk/resources/300-W/) \] +- [300VW](#300vw-dataset) \[ [Homepage](https://ibug.doc.ic.ac.uk/resources/300-VW/) \] - [WFLW](#wflw-dataset) \[ [Homepage](https://wywu.github.io/projects/LAB/WFLW.html) \] - [AFLW](#aflw-dataset) \[ [Homepage](https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/aflw/) \] - [COFW](#cofw-dataset) \[ [Homepage](http://www.vision.caltech.edu/xpburgos/ICCV13/) \] @@ -94,6 +95,60 @@ mmpose ... ``` +## 300VW Dataset + + + +
+300VW (ICCVW'2015) + +```bibtex +@inproceedings{shen2015first, + title={The first facial landmark tracking in-the-wild challenge: Benchmark and results}, + author={Shen, Jie and Zafeiriou, Stefanos and Chrysos, Grigoris G and Kossaifi, Jean and Tzimiropoulos, Georgios and Pantic, Maja}, + booktitle={Proceedings of the IEEE international conference on computer vision workshops}, + pages={50--58}, + year={2015} +} +``` + +
+
+For 300VW data, please register and download images from [300VW Dataset](https://ibug.doc.ic.ac.uk/download/300VW_Dataset_2015_12_14.zip).
+Unzip it and use `tools/dataset_converters/300vw2coco.py` to process the data.
+
+
+
+Put the 300VW data under `{MMPose}/data`, and make it look like this:
+
+```text
+mmpose
+├── mmpose
+├── docs
+├── tests
+├── tools
+├── configs
+`── data
+    │── 300vw
+       |── annotations
+       |   |── train.json
+       |   |── test_1.json
+       |   |── test_2.json
+       |   `── test_3.json
+       `── images
+          |── 001
+          |   `── imgs
+          |      |── 000001.png
+          |      |── 000002.png
+          |      ...
+          |── 002
+          |   `── imgs
+          |      |── 000001.png
+          |      |── 000002.png
+          |      ...
+          | ...
+```
+
 ## WFLW Dataset
 
diff --git a/docs/src/papers/datasets/300vw.md b/docs/src/papers/datasets/300vw.md
new file mode 100644
index 0000000000..53859f6da5
--- /dev/null
+++ b/docs/src/papers/datasets/300vw.md
@@ -0,0 +1,18 @@
+# The first facial landmark tracking in-the-wild challenge: Benchmark and results
+
+
+
+300VW (ICCVW'2015) + +```bibtex +@inproceedings{shen2015first, + title={The first facial landmark tracking in-the-wild challenge: Benchmark and results}, + author={Shen, Jie and Zafeiriou, Stefanos and Chrysos, Grigoris G and Kossaifi, Jean and Tzimiropoulos, Georgios and Pantic, Maja}, + booktitle={Proceedings of the IEEE international conference on computer vision workshops}, + pages={50--58}, + year={2015} +} +``` + +
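For reference, wiring the new dataset into a top-down config could look like the sketch below. The `type` name and the `data/300vw` directory layout come from this PR's docs; the batch size, worker count, and the elided pipeline are illustrative placeholders only, not part of the PR.

```python
# Hypothetical config fragment. Paths assume the directory layout
# documented in docs/en/dataset_zoo/2d_face_keypoint.md.
train_dataloader = dict(
    batch_size=32,  # placeholder
    num_workers=2,  # placeholder
    dataset=dict(
        type='Face300VWDataset',
        data_root='data/300vw',
        data_mode='topdown',
        ann_file='annotations/train.json',
        data_prefix=dict(img='images/'),
        pipeline=[],  # fill in a standard face top-down pipeline
    ))
```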
diff --git a/docs/zh_cn/dataset_zoo/2d_face_keypoint.md b/docs/zh_cn/dataset_zoo/2d_face_keypoint.md index f0861bd2c3..da6cf1b0e0 100644 --- a/docs/zh_cn/dataset_zoo/2d_face_keypoint.md +++ b/docs/zh_cn/dataset_zoo/2d_face_keypoint.md @@ -6,6 +6,7 @@ If your folder structure is different, you may need to change the corresponding MMPose supported datasets: - [300W](#300w-dataset) \[ [Homepage](https://ibug.doc.ic.ac.uk/resources/300-W/) \] +- [300VW](#300vw-dataset) \[ [Homepage](https://ibug.doc.ic.ac.uk/resources/300-VW/) \] - [WFLW](#wflw-dataset) \[ [Homepage](https://wywu.github.io/projects/LAB/WFLW.html) \] - [AFLW](#aflw-dataset) \[ [Homepage](https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/aflw/) \] - [COFW](#cofw-dataset) \[ [Homepage](http://www.vision.caltech.edu/xpburgos/ICCV13/) \] @@ -94,6 +95,61 @@ mmpose ... ``` +## 300VW Dataset + + + +
+300VW (ICCVW'2015) + +```bibtex +@inproceedings{shen2015first, + title={The first facial landmark tracking in-the-wild challenge: Benchmark and results}, + author={Shen, Jie and Zafeiriou, Stefanos and Chrysos, Grigoris G and Kossaifi, Jean and Tzimiropoulos, Georgios and Pantic, Maja}, + booktitle={Proceedings of the IEEE international conference on computer vision workshops}, + pages={50--58}, + year={2015} +} +``` + +
+
+The 300VW dataset follows the same mark-up (i.e., set of facial landmarks) as 300W.
+For 300VW data, please register and download images from [300VW Dataset](https://ibug.doc.ic.ac.uk/download/300VW_Dataset_2015_12_14.zip).
+Unzip it and use `tools/dataset_converters/300vw2coco.py` to process the data.
+
+
+
+Put the 300VW data under `{MMPose}/data`, and make it look like this:
+
+```text
+mmpose
+├── mmpose
+├── docs
+├── tests
+├── tools
+├── configs
+`── data
+    │── 300vw
+       |── annotations
+       |   |── train.json
+       |   |── test_1.json
+       |   |── test_2.json
+       |   `── test_3.json
+       `── images
+          |── 001
+          |   `── imgs
+          |      |── 000001.png
+          |      |── 000002.png
+          |      ...
+          |── 002
+          |   `── imgs
+          |      |── 000001.png
+          |      |── 000002.png
+          |      ...
+          | ...
+```
+
 ## WFLW Dataset
 
diff --git a/mmpose/datasets/datasets/face/__init__.py b/mmpose/datasets/datasets/face/__init__.py
index 1b78d87502..21c531f6bc 100644
--- a/mmpose/datasets/datasets/face/__init__.py
+++ b/mmpose/datasets/datasets/face/__init__.py
@@ -2,6 +2,7 @@
 from .aflw_dataset import AFLWDataset
 from .coco_wholebody_face_dataset import CocoWholeBodyFaceDataset
 from .cofw_dataset import COFWDataset
+from .face_300vw_dataset import Face300VWDataset
 from .face_300w_dataset import Face300WDataset
 from .face_300wlp_dataset import Face300WLPDataset
 from .lapa_dataset import LapaDataset
@@ -9,5 +10,6 @@
 __all__ = [
     'Face300WDataset', 'WFLWDataset', 'AFLWDataset', 'COFWDataset',
-    'CocoWholeBodyFaceDataset', 'LapaDataset', 'Face300WLPDataset'
+    'CocoWholeBodyFaceDataset', 'LapaDataset', 'Face300WLPDataset',
+    'Face300VWDataset'
 ]
diff --git a/mmpose/datasets/datasets/face/face_300vw_dataset.py b/mmpose/datasets/datasets/face/face_300vw_dataset.py
new file mode 100644
index 0000000000..35979c618f
--- /dev/null
+++ b/mmpose/datasets/datasets/face/face_300vw_dataset.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Optional
+
+import numpy as np
+
+from mmpose.registry import DATASETS
+from mmpose.structures.bbox import bbox_cs2xyxy
+from ..base import BaseCocoStyleDataset
+
+
+@DATASETS.register_module()
+class Face300VWDataset(BaseCocoStyleDataset):
+    """300VW dataset for face keypoint tracking.
+
+    "The First Facial Landmark Tracking in-the-Wild Challenge:
+    Benchmark and Results",
+    Proceedings of the IEEE
+    international conference on computer vision workshops.
+
+    The landmark annotations follow the 68-point mark-up. The definition
+    can be found in `https://ibug.doc.ic.ac.uk/resources/300-VW/`.
+
+    Args:
+        ann_file (str): Annotation file path. Default: ''.
+        bbox_file (str, optional): Detection result file path. If
+            ``bbox_file`` is set, detected bboxes loaded from this file will
+            be used instead of ground-truth bboxes. This setting is only for
+            evaluation, i.e., ignored when ``test_mode`` is ``False``.
+            Default: ``None``.
+        data_mode (str): Specifies the mode of data samples: ``'topdown'`` or
+            ``'bottomup'``. In ``'topdown'`` mode, each data sample contains
+            one instance; while in ``'bottomup'`` mode, each data sample
+            contains all instances in an image. Default: ``'topdown'``
+        metainfo (dict, optional): Meta information for dataset, such as class
+            information. Default: ``None``.
+        data_root (str, optional): The root directory for ``data_prefix`` and
+            ``ann_file``. Default: ``None``.
+        data_prefix (dict, optional): Prefix for training data. Default:
+            ``dict(img=None, ann=None)``.
+        filter_cfg (dict, optional): Config for filtering data. Default: `None`.
+        indices (int or Sequence[int], optional): Support using first few
+            data in annotation file to facilitate training/testing on a smaller
+            dataset. Default: ``None`` which means using all ``data_infos``.
+        serialize_data (bool, optional): Whether to hold memory using
+            serialized objects, when enabled, data loader workers can use
+            shared RAM from master process instead of making a copy.
+            Default: ``True``.
+        pipeline (list, optional): Processing pipeline. Default: [].
+        test_mode (bool, optional): ``test_mode=True`` means in test phase.
+            Default: ``False``.
+        lazy_init (bool, optional): Whether to load annotation during
+            instantiation. In some cases, such as visualization, only the meta
+            information of the dataset is needed, so it is unnecessary to load
+            the annotation file. ``BaseDataset`` can skip loading annotations
+            to save time by setting ``lazy_init=True``. Default: ``False``.
+        max_refetch (int, optional): If ``BaseDataset.prepare_data`` gets a
+            None img, the maximum extra number of cycles to get a valid
+            image. Default: 1000.
+    """
+
+    METAINFO: dict = dict(from_file='configs/_base_/datasets/300vw.py')
+
+    def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+        """Parse raw Face300VW annotation of an instance.
+
+        Args:
+            raw_data_info (dict): Raw data information loaded from
+                ``ann_file``. It should have the following contents:
+
+                - ``'raw_ann_info'``: Raw annotation of an instance
+                - ``'raw_img_info'``: Raw information of the image that
+                    contains the instance
+
+        Returns:
+            dict: Parsed instance annotation
+        """
+
+        ann = raw_data_info['raw_ann_info']
+        img = raw_data_info['raw_img_info']
+
+        img_path = osp.join(self.data_prefix['img'], img['file_name'])
+
+        # 300vw bbox scales are normalized with a factor of 200.
+        pixel_std = 200.
+
+        # center, scale in shape [1, 2] and bbox in [1, 4]
+        center = np.array([ann['center']], dtype=np.float32)
+        scale = np.array([[ann['scale'], ann['scale']]],
+                         dtype=np.float32) * pixel_std
+        bbox = bbox_cs2xyxy(center, scale)
+
+        # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+        _keypoints = np.array(
+            ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+        keypoints = _keypoints[..., :2]
+        keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+        num_keypoints = ann['num_keypoints']
+
+        data_info = {
+            'img_id': ann['image_id'],
+            'img_path': img_path,
+            'bbox': bbox,
+            'bbox_center': center,
+            'bbox_scale': scale,
+            'bbox_score': np.ones(1, dtype=np.float32),
+            'num_keypoints': num_keypoints,
+            'keypoints': keypoints,
+            'keypoints_visible': keypoints_visible,
+            'iscrowd': ann['iscrowd'],
+            'id': ann['id'],
+        }
+        return data_info
diff --git a/mmpose/visualization/fast_visualizer.py b/mmpose/visualization/fast_visualizer.py
index fa0cb38527..fdac734575 100644
--- a/mmpose/visualization/fast_visualizer.py
+++ b/mmpose/visualization/fast_visualizer.py
@@ -1,5 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple, Union
+
 import cv2
+import numpy as np
+
+
+class Instances:
+    keypoints: List[List[Tuple[int, int]]]
+    keypoint_scores: List[List[float]]
 
 
 class FastVisualizer:
@@ -9,7 +17,7 @@ class FastVisualizer:
 
     Args:
         metainfo (dict): pose meta information
-        radius (int, optional)): Keypoint radius for visualization.
+        radius (int, optional): Keypoint radius for visualization.
             Defaults to 6.
         line_width (int, optional): Link width for visualization.
             Defaults to 3.
         kpt_thr (float, optional): Threshold of keypoints' score.
            Defaults to 0.3. 
""" - def __init__(self, metainfo, radius=6, line_width=3, kpt_thr=0.3): + def __init__(self, + metainfo: Dict, + radius: Optional[int] = 6, + line_width: Optional[int] = 3, + kpt_thr: Optional[float] = 0.3): self.radius = radius self.line_width = line_width self.kpt_thr = kpt_thr - self.keypoint_id2name = metainfo['keypoint_id2name'] - self.keypoint_name2id = metainfo['keypoint_name2id'] - self.keypoint_colors = metainfo['keypoint_colors'] - self.skeleton_links = metainfo['skeleton_links'] - self.skeleton_link_colors = metainfo['skeleton_link_colors'] + self.keypoint_id2name = metainfo.get('keypoint_id2name', None) + self.keypoint_name2id = metainfo.get('keypoint_name2id', None) + self.keypoint_colors = metainfo.get('keypoint_colors', + [(255, 255, 255)]) + self.skeleton_links = metainfo.get('skeleton_links', None) + self.skeleton_link_colors = metainfo.get('skeleton_link_colors', None) - def draw_pose(self, img, instances): + def draw_pose(self, img: np.ndarray, instances: Instances): """Draw pose estimations on the given image. This method draws keypoints and skeleton links on the input image @@ -76,3 +89,64 @@ def draw_pose(self, img, instances): color, -1) cv2.circle(img, (int(x_coord), int(y_coord)), self.radius, (255, 255, 255)) + + def draw_points(self, img: np.ndarray, instances: Union[Instances, Dict, + np.ndarray]): + """Draw points on the given image. + + This method draws keypoints on the input image + using the provided instances. + + Args: + img (numpy.ndarray): The input image on which to + draw the keypoints. + instances (object|dict|np.ndarray): + An object containing keypoints, + or a dict containing 'keypoints', + or a np.ndarray in shape of + (Instance_num, Point_num, Point_dim) + + Returns: + None: The input image will be modified in place. + """ + + if instances is None: + print('no instance detected') + return + + # support different types of keypoints inputs + if hasattr(instances, 'keypoints'): + keypoints = instances.keypoints + elif isinstance(instances, dict) and 'keypoints' in instances: + keypoints = instances['keypoints'] + elif isinstance(instances, np.ndarray): + shape = instances.shape + assert shape[-1] == 2, 'only support 2-dim point!' 
+ if len(shape) == 2: + keypoints = instances[None] + elif len(shape) == 3: + pass + else: + raise ValueError('input keypoints should be in shape of' + '(Instance_num, Point_num, Point_dim)') + else: + raise ValueError('The keypoints should be:' + 'object containing keypoints,' + "or a dict containing 'keypoints'," + 'or a np.ndarray in shape of' + '(Instance_num, Point_num, Point_dim)') + + if len(self.keypoint_colors) < len(keypoints[0]): + repeat_num = len(keypoints[0]) - len(self.keypoint_colors) + self.keypoint_colors += [(255, 255, 255)] * repeat_num + self.keypoint_colors = np.array(self.keypoint_colors) + + for kpts in keypoints: + for kid, kpt in enumerate(kpts): + x_coord, y_coord = int(kpt[0]), int(kpt[1]) + + color = self.keypoint_colors[kid].tolist() + cv2.circle(img, (int(x_coord), int(y_coord)), self.radius, + color, -1) + cv2.circle(img, (int(x_coord), int(y_coord)), self.radius, + (255, 255, 255)) diff --git a/tests/data/300vw/001/annot/000006.pts b/tests/data/300vw/001/annot/000006.pts new file mode 100644 index 0000000000..3889aa8f93 --- /dev/null +++ b/tests/data/300vw/001/annot/000006.pts @@ -0,0 +1,72 @@ +version: 1 +n_points: 68 +{ +743.224 197.764 +744.067 213.767 +747.494 229.362 +752.659 245.009 +758.92 259.223 +767.901 272.356 +778.802 283.44 +791.215 292.702 +805.644 296.179 +822.777 294.269 +840.544 285.779 +855.834 273.961 +867.449 258.999 +873.996 241.046 +876.935 222.009 +878.567 202.398 +878.29 182.498 +744.007 167.503 +750.098 162.094 +758.537 160.906 +767.509 162.778 +775.861 167.055 +801.008 165.555 +812.685 160.407 +825.422 158.165 +838.176 160.521 +849.128 166.753 +789.759 186.16 +789.699 198.174 +789.273 210.165 +788.975 222.281 +782.603 232.478 +787.429 234.27 +792.678 235.337 +798.527 234.133 +804.664 232.902 +755.736 189.198 +761.237 184.449 +769.829 184.474 +778 190.301 +769.807 192.891 +761.326 193.176 +811.185 189.658 +818.224 183.243 +827.82 183.095 +836.647 187.198 +829.057 191.518 +819.507 191.948 +778.412 257.689 +782.738 250.599 +789.035 246.573 +794.615 248.185 +800.403 246.606 +811.204 250.488 +822.537 256.574 +813.083 264.925 +803.39 268.215 +796.966 268.73 +790.759 267.758 +784.142 264.276 +782.377 257.528 +789.969 255.358 +795.418 255.634 +801.363 255.421 +817.652 256.63 +801.613 256.589 +795.658 256.775 +790.094 256.132 +} diff --git a/tests/data/300vw/001/annot/000009.pts b/tests/data/300vw/001/annot/000009.pts new file mode 100644 index 0000000000..7435d19b43 --- /dev/null +++ b/tests/data/300vw/001/annot/000009.pts @@ -0,0 +1,72 @@ +version: 1 +n_points: 68 +{ +742.873 203.975 +743.86 219.009 +747.556 233.644 +752.945 248.605 +758.873 262.435 +767.402 275.713 +778.158 286.99 +790.906 296.386 +805.448 299.882 +823.054 297.715 +841.765 288.765 +857.714 276.621 +869.301 261.256 +875.322 243.141 +877.652 224.024 +878.624 204.199 +877.759 184.091 +739.839 173.49 +745.097 169.057 +752.507 168.519 +760.509 170.747 +768.073 174.906 +791.577 172.962 +803.164 167.369 +816.229 164.505 +829.601 166.297 +841.772 171.687 +781.866 193.207 +781.874 205.541 +781.423 217.922 +781.105 230.353 +777.568 240.329 +781.547 242.003 +786.186 242.803 +791.695 241.762 +797.794 240.677 +751.554 195.364 +756.158 190.45 +764.343 190.443 +772.432 196.54 +764.574 199.384 +756.473 199.727 +804.344 195.462 +810.661 188.458 +820.197 188.077 +829.408 192.078 +821.842 196.959 +812.389 197.654 +777.821 266.056 +779.868 259.244 +784.392 255.175 +789.378 256.665 +794.399 255.071 +805.22 258.831 +817.54 264.382 +807.933 271.958 +798.515 274.526 +792.885 274.904 +787.314 
273.956 +782.133 270.936 +781.614 265.715 +785.822 264.012 +790.616 264.186 +795.84 263.973 +812.474 264.499 +796.13 262.77 +790.946 262.751 +786.04 262.145 +} diff --git a/tests/data/300vw/001/imgs/000006.jpg b/tests/data/300vw/001/imgs/000006.jpg new file mode 100644 index 0000000000..18bbc740a1 Binary files /dev/null and b/tests/data/300vw/001/imgs/000006.jpg differ diff --git a/tests/data/300vw/001/imgs/000009.jpg b/tests/data/300vw/001/imgs/000009.jpg new file mode 100644 index 0000000000..b9782a65e1 Binary files /dev/null and b/tests/data/300vw/001/imgs/000009.jpg differ diff --git a/tests/data/300vw/401/annot/000731.pts b/tests/data/300vw/401/annot/000731.pts new file mode 100644 index 0000000000..181faed42a --- /dev/null +++ b/tests/data/300vw/401/annot/000731.pts @@ -0,0 +1,72 @@ +version: 1.0 +n_points: 68 +{ +354.326810 277.170196 +356.862516 297.896319 +360.708857 321.402933 +366.772339 344.811744 +375.576753 367.589576 +388.959244 386.799754 +407.262945 401.811444 +428.068560 412.214259 +453.614796 416.691435 +471.226590 411.577691 +484.005875 396.976013 +494.537085 376.929648 +502.043375 358.298449 +508.351854 337.612143 +513.750512 315.400709 +517.398128 293.446696 +519.878097 269.587778 +401.071060 289.120463 +415.868158 283.175462 +431.633197 280.639818 +447.288182 280.860303 +462.502056 283.726681 +493.759549 279.708941 +501.023665 273.761758 +509.227751 268.874220 +517.254018 266.415103 +522.849079 267.392002 +483.202797 299.142724 +485.261679 313.765218 +487.619687 327.980627 +490.111994 342.128622 +462.545119 348.143214 +473.539324 350.954483 +484.012453 351.175083 +491.398925 348.749592 +495.698470 344.339799 +415.040415 304.069825 +427.740254 304.565914 +438.181828 303.487879 +449.618639 300.557420 +438.347956 305.781678 +427.813322 306.703437 +492.929038 293.082034 +501.907509 291.762309 +510.255618 289.897320 +516.708942 286.473615 +510.344380 291.584684 +501.907485 293.272054 +431.746485 373.402051 +449.048791 371.628950 +465.248833 370.765301 +471.581789 371.573768 +478.582506 369.215815 +484.976525 368.995428 +492.130309 368.389324 +485.662685 378.849941 +478.521143 384.643953 +470.975365 386.597756 +462.553999 386.328249 +448.270860 382.261998 +434.809022 374.014486 +464.103456 376.087702 +471.312425 376.020313 +478.521168 374.807624 +489.607029 369.877287 +478.184380 376.222430 +471.312350 377.165647 +463.901645 377.165623 +} diff --git a/tests/data/300vw/401/annot/000732.pts b/tests/data/300vw/401/annot/000732.pts new file mode 100644 index 0000000000..bf0dc33355 --- /dev/null +++ b/tests/data/300vw/401/annot/000732.pts @@ -0,0 +1,72 @@ +version: 1 +n_points: 68 +{ +363.817 271.708 +364.252 298.829 +365.968 325.877 +370.470 351.633 +381.256 375.027 +397.661 394.398 +417.463 409.304 +438.872 420.327 +460.669 420.571 +478.411 411.705 +487.951 392.441 +496.754 372.500 +504.338 353.575 +509.906 335.599 +514.506 317.927 +516.599 300.991 +515.792 284.064 +404.467 290.300 +418.546 282.394 +434.658 278.447 +450.617 278.359 +465.479 282.345 +496.947 282.828 +503.613 274.821 +511.716 269.283 +520.833 266.200 +527.099 269.220 +481.812 299.263 +483.596 314.370 +485.608 329.230 +487.438 344.276 +462.064 350.710 +471.954 354.403 +481.878 355.891 +489.689 352.322 +494.701 346.437 +419.526 304.128 +432.296 302.727 +442.685 299.918 +451.791 301.013 +442.576 304.957 +432.247 306.164 +493.153 294.574 +502.979 291.036 +511.389 289.049 +517.585 288.486 +511.946 293.834 +503.509 295.495 +434.752 372.539 +452.756 370.355 +467.045 368.214 +474.267 369.293 +480.141 366.373 +484.881 
365.139 +488.812 365.241 +483.252 374.699 +478.597 382.170 +471.995 385.121 +464.377 385.874 +450.608 382.911 +440.029 373.013 +466.575 374.235 +473.917 373.978 +479.839 371.716 +484.497 367.569 +478.771 372.052 +472.793 375.438 +465.679 376.251 +} diff --git a/tests/data/300vw/401/imgs/000731.jpg b/tests/data/300vw/401/imgs/000731.jpg new file mode 100644 index 0000000000..d3f91e17d0 Binary files /dev/null and b/tests/data/300vw/401/imgs/000731.jpg differ diff --git a/tests/data/300vw/401/imgs/000732.jpg b/tests/data/300vw/401/imgs/000732.jpg new file mode 100644 index 0000000000..04b5bf7723 Binary files /dev/null and b/tests/data/300vw/401/imgs/000732.jpg differ diff --git a/tests/data/300vw/anno_300vw.json b/tests/data/300vw/anno_300vw.json new file mode 100644 index 0000000000..2178d27fe0 --- /dev/null +++ b/tests/data/300vw/anno_300vw.json @@ -0,0 +1,690 @@ +{ + "images": [ + { + "file_name": "001/imgs/000006.jpg", + "height": 720, + "width": 1280, + "id": 0 + }, + { + "file_name": "001/imgs/000009.jpg", + "height": 720, + "width": 1280, + "id": 1 + }, + { + "file_name": "401/imgs/000732.jpg", + "height": 512, + "width": 700, + "id": 2 + } + ], + "annotations": [ + { + "segmentation": [], + "num_keypoints": 68, + "iscrowd": 0, + "category_id": 1, + "keypoints": [ + 743.224, + 197.764, + 1, + 744.067, + 213.767, + 1, + 747.494, + 229.362, + 1, + 752.659, + 245.009, + 1, + 758.92, + 259.223, + 1, + 767.901, + 272.356, + 1, + 778.802, + 283.44, + 1, + 791.215, + 292.702, + 1, + 805.644, + 296.179, + 1, + 822.777, + 294.269, + 1, + 840.544, + 285.779, + 1, + 855.834, + 273.961, + 1, + 867.449, + 258.999, + 1, + 873.996, + 241.046, + 1, + 876.935, + 222.009, + 1, + 878.567, + 202.398, + 1, + 878.29, + 182.498, + 1, + 744.007, + 167.503, + 1, + 750.098, + 162.094, + 1, + 758.537, + 160.906, + 1, + 767.509, + 162.778, + 1, + 775.861, + 167.055, + 1, + 801.008, + 165.555, + 1, + 812.685, + 160.407, + 1, + 825.422, + 158.165, + 1, + 838.176, + 160.521, + 1, + 849.128, + 166.753, + 1, + 789.759, + 186.16, + 1, + 789.699, + 198.174, + 1, + 789.273, + 210.165, + 1, + 788.975, + 222.281, + 1, + 782.603, + 232.478, + 1, + 787.429, + 234.27, + 1, + 792.678, + 235.337, + 1, + 798.527, + 234.133, + 1, + 804.664, + 232.902, + 1, + 755.736, + 189.198, + 1, + 761.237, + 184.449, + 1, + 769.829, + 184.474, + 1, + 778.0, + 190.301, + 1, + 769.807, + 192.891, + 1, + 761.326, + 193.176, + 1, + 811.185, + 189.658, + 1, + 818.224, + 183.243, + 1, + 827.82, + 183.095, + 1, + 836.647, + 187.198, + 1, + 829.057, + 191.518, + 1, + 819.507, + 191.948, + 1, + 778.412, + 257.689, + 1, + 782.738, + 250.599, + 1, + 789.035, + 246.573, + 1, + 794.615, + 248.185, + 1, + 800.403, + 246.606, + 1, + 811.204, + 250.488, + 1, + 822.537, + 256.574, + 1, + 813.083, + 264.925, + 1, + 803.39, + 268.215, + 1, + 796.966, + 268.73, + 1, + 790.759, + 267.758, + 1, + 784.142, + 264.276, + 1, + 782.377, + 257.528, + 1, + 789.969, + 255.358, + 1, + 795.418, + 255.634, + 1, + 801.363, + 255.421, + 1, + 817.652, + 256.63, + 1, + 801.613, + 256.589, + 1, + 795.658, + 256.775, + 1, + 790.094, + 256.132, + 1 + ], + "scale": 0.695, + "area": 18679.22880199999, + "center": [ + 810.8955000000001, + 227.17199999999997 + ], + "image_id": 0, + "id": 0 + }, + { + "segmentation": [], + "num_keypoints": 68, + "iscrowd": 0, + "category_id": 1, + "keypoints": [ + 742.873, + 203.975, + 1, + 743.86, + 219.009, + 1, + 747.556, + 233.644, + 1, + 752.945, + 248.605, + 1, + 758.873, + 262.435, + 1, + 767.402, + 275.713, + 1, + 778.158, + 286.99, + 1, + 
790.906, + 296.386, + 1, + 805.448, + 299.882, + 1, + 823.054, + 297.715, + 1, + 841.765, + 288.765, + 1, + 857.714, + 276.621, + 1, + 869.301, + 261.256, + 1, + 875.322, + 243.141, + 1, + 877.652, + 224.024, + 1, + 878.624, + 204.199, + 1, + 877.759, + 184.091, + 1, + 739.839, + 173.49, + 1, + 745.097, + 169.057, + 1, + 752.507, + 168.519, + 1, + 760.509, + 170.747, + 1, + 768.073, + 174.906, + 1, + 791.577, + 172.962, + 1, + 803.164, + 167.369, + 1, + 816.229, + 164.505, + 1, + 829.601, + 166.297, + 1, + 841.772, + 171.687, + 1, + 781.866, + 193.207, + 1, + 781.874, + 205.541, + 1, + 781.423, + 217.922, + 1, + 781.105, + 230.353, + 1, + 777.568, + 240.329, + 1, + 781.547, + 242.003, + 1, + 786.186, + 242.803, + 1, + 791.695, + 241.762, + 1, + 797.794, + 240.677, + 1, + 751.554, + 195.364, + 1, + 756.158, + 190.45, + 1, + 764.343, + 190.443, + 1, + 772.432, + 196.54, + 1, + 764.574, + 199.384, + 1, + 756.473, + 199.727, + 1, + 804.344, + 195.462, + 1, + 810.661, + 188.458, + 1, + 820.197, + 188.077, + 1, + 829.408, + 192.078, + 1, + 821.842, + 196.959, + 1, + 812.389, + 197.654, + 1, + 777.821, + 266.056, + 1, + 779.868, + 259.244, + 1, + 784.392, + 255.175, + 1, + 789.378, + 256.665, + 1, + 794.399, + 255.071, + 1, + 805.22, + 258.831, + 1, + 817.54, + 264.382, + 1, + 807.933, + 271.958, + 1, + 798.515, + 274.526, + 1, + 792.885, + 274.904, + 1, + 787.314, + 273.956, + 1, + 782.133, + 270.936, + 1, + 781.614, + 265.715, + 1, + 785.822, + 264.012, + 1, + 790.616, + 264.186, + 1, + 795.84, + 263.973, + 1, + 812.474, + 264.499, + 1, + 796.13, + 262.77, + 1, + 790.946, + 262.751, + 1, + 786.04, + 262.145, + 1 + ], + "scale": 0.695, + "area": 18788.296945, + "center": [ + 809.2315000000001, + 232.1935 + ], + "image_id": 1, + "id": 1 + }, + { + "segmentation": [], + "num_keypoints": 68, + "iscrowd": 0, + "category_id": 1, + "keypoints": [ + 363.817, + 271.708, + 1, + 364.252, + 298.829, + 1, + 365.968, + 325.877, + 1, + 370.47, + 351.633, + 1, + 381.256, + 375.027, + 1, + 397.661, + 394.398, + 1, + 417.463, + 409.304, + 1, + 438.872, + 420.327, + 1, + 460.669, + 420.571, + 1, + 478.411, + 411.705, + 1, + 487.951, + 392.441, + 1, + 496.754, + 372.5, + 1, + 504.338, + 353.575, + 1, + 509.906, + 335.599, + 1, + 514.506, + 317.927, + 1, + 516.599, + 300.991, + 1, + 515.792, + 284.064, + 1, + 404.467, + 290.3, + 1, + 418.546, + 282.394, + 1, + 434.658, + 278.447, + 1, + 450.617, + 278.359, + 1, + 465.479, + 282.345, + 1, + 496.947, + 282.828, + 1, + 503.613, + 274.821, + 1, + 511.716, + 269.283, + 1, + 520.833, + 266.2, + 1, + 527.099, + 269.22, + 1, + 481.812, + 299.263, + 1, + 483.596, + 314.37, + 1, + 485.608, + 329.23, + 1, + 487.438, + 344.276, + 1, + 462.064, + 350.71, + 1, + 471.954, + 354.403, + 1, + 481.878, + 355.891, + 1, + 489.689, + 352.322, + 1, + 494.701, + 346.437, + 1, + 419.526, + 304.128, + 1, + 432.296, + 302.727, + 1, + 442.685, + 299.918, + 1, + 451.791, + 301.013, + 1, + 442.576, + 304.957, + 1, + 432.247, + 306.164, + 1, + 493.153, + 294.574, + 1, + 502.979, + 291.036, + 1, + 511.389, + 289.049, + 1, + 517.585, + 288.486, + 1, + 511.946, + 293.834, + 1, + 503.509, + 295.495, + 1, + 434.752, + 372.539, + 1, + 452.756, + 370.355, + 1, + 467.045, + 368.214, + 1, + 474.267, + 369.293, + 1, + 480.141, + 366.373, + 1, + 484.881, + 365.139, + 1, + 488.812, + 365.241, + 1, + 483.252, + 374.699, + 1, + 478.597, + 382.17, + 1, + 471.995, + 385.121, + 1, + 464.377, + 385.874, + 1, + 450.608, + 382.911, + 1, + 440.029, + 373.013, + 1, + 466.575, + 374.235, + 1, + 473.917, + 373.978, + 
1, + 479.839, + 371.716, + 1, + 484.497, + 367.569, + 1, + 478.771, + 372.052, + 1, + 472.793, + 375.438, + 1, + 465.679, + 376.251, + 1 + ], + "scale": 0.82, + "area": 25206.00562200001, + "center": [ + 445.458, + 343.3855 + ], + "image_id": 2, + "id": 2 + } + ], + "categories": [ + { + "id": 1, + "name": "person" + } + ] +} \ No newline at end of file diff --git a/tests/data/300vw/broken_frames.npy b/tests/data/300vw/broken_frames.npy new file mode 100644 index 0000000000..e698d6d3c5 Binary files /dev/null and b/tests/data/300vw/broken_frames.npy differ diff --git a/tests/test_datasets/test_datasets/test_face_datasets/test_face_300vw_dataset.py b/tests/test_datasets/test_datasets/test_face_datasets/test_face_300vw_dataset.py new file mode 100644 index 0000000000..6e124b0496 --- /dev/null +++ b/tests/test_datasets/test_datasets/test_face_datasets/test_face_300vw_dataset.py @@ -0,0 +1,143 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np + +from mmpose.datasets.datasets.face import Face300VWDataset + + +class TestFace300VWDataset(TestCase): + + def build_face_300vw_dataset(self, **kwargs): + + cfg = dict( + ann_file='anno_300vw.json', + bbox_file=None, + data_mode='topdown', + data_root='tests/data/300vw', + pipeline=[], + test_mode=False) + + cfg.update(kwargs) + return Face300VWDataset(**cfg) + + def check_data_info_keys(self, + data_info: dict, + data_mode: str = 'topdown'): + if data_mode == 'topdown': + expected_keys = dict( + img_id=int, + img_path=str, + bbox_center=np.ndarray, + bbox_scale=np.ndarray, + bbox_score=np.ndarray, + keypoints=np.ndarray, + keypoints_visible=np.ndarray, + id=int) + elif data_mode == 'bottomup': + expected_keys = dict( + img_id=int, + img_path=str, + bbox_center=np.ndarray, + bbox_scale=np.ndarray, + bbox_score=np.ndarray, + keypoints=np.ndarray, + keypoints_visible=np.ndarray, + invalid_segs=list, + id=list) + else: + raise ValueError(f'Invalid data_mode {data_mode}') + + for key, type_ in expected_keys.items(): + self.assertIn(key, data_info) + self.assertIsInstance(data_info[key], type_, key) + + def check_metainfo_keys(self, metainfo: dict): + expected_keys = dict( + dataset_name=str, + num_keypoints=int, + keypoint_id2name=dict, + keypoint_name2id=dict, + upper_body_ids=list, + lower_body_ids=list, + flip_indices=list, + flip_pairs=list, + keypoint_colors=np.ndarray, + num_skeleton_links=int, + skeleton_links=list, + skeleton_link_colors=np.ndarray, + dataset_keypoint_weights=np.ndarray) + + for key, type_ in expected_keys.items(): + self.assertIn(key, metainfo) + self.assertIsInstance(metainfo[key], type_, key) + + def test_metainfo(self): + dataset = self.build_face_300vw_dataset() + self.check_metainfo_keys(dataset.metainfo) + # test dataset_name + self.assertEqual(dataset.metainfo['dataset_name'], '300vw') + + # test number of keypoints + num_keypoints = 68 + self.assertEqual(dataset.metainfo['num_keypoints'], num_keypoints) + self.assertEqual( + len(dataset.metainfo['keypoint_colors']), num_keypoints) + self.assertEqual( + len(dataset.metainfo['dataset_keypoint_weights']), num_keypoints) + + def test_topdown(self): + # test topdown training + dataset = self.build_face_300vw_dataset(data_mode='topdown') + self.assertEqual(dataset.data_mode, 'topdown') + self.assertEqual(dataset.bbox_file, None) + self.assertEqual(len(dataset), 3) + self.check_data_info_keys(dataset[0]) + + # test topdown testing + dataset = self.build_face_300vw_dataset( + data_mode='topdown', test_mode=True) + 
self.assertEqual(dataset.data_mode, 'topdown') + self.assertEqual(dataset.bbox_file, None) + self.assertEqual(len(dataset), 3) + self.check_data_info_keys(dataset[0]) + + def test_bottomup(self): + # test bottomup training + dataset = self.build_face_300vw_dataset(data_mode='bottomup') + self.assertEqual(len(dataset), 3) + self.check_data_info_keys(dataset[0], data_mode='bottomup') + + # test bottomup testing + dataset = self.build_face_300vw_dataset( + data_mode='bottomup', test_mode=True) + self.assertEqual(len(dataset), 3) + self.check_data_info_keys(dataset[0], data_mode='bottomup') + + def test_exceptions_and_warnings(self): + + with self.assertRaisesRegex(ValueError, 'got invalid data_mode'): + _ = self.build_face_300vw_dataset(data_mode='invalid') + + with self.assertRaisesRegex( + ValueError, + '"bbox_file" is only supported when `test_mode==True`'): + _ = self.build_face_300vw_dataset( + data_mode='topdown', + test_mode=False, + bbox_file='temp_bbox_file.json') + + with self.assertRaisesRegex( + ValueError, '"bbox_file" is only supported in topdown mode'): + _ = self.build_face_300vw_dataset( + data_mode='bottomup', + test_mode=True, + bbox_file='temp_bbox_file.json') + + with self.assertRaisesRegex( + ValueError, + '"bbox_score_thr" is only supported in topdown mode'): + _ = self.build_face_300vw_dataset( + data_mode='bottomup', + test_mode=True, + filter_cfg=dict(bbox_score_thr=0.3)) diff --git a/tools/dataset_converters/300vw2coco.py b/tools/dataset_converters/300vw2coco.py new file mode 100644 index 0000000000..8be0083266 --- /dev/null +++ b/tools/dataset_converters/300vw2coco.py @@ -0,0 +1,254 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import json +import math +import multiprocessing +import os +import subprocess +from glob import glob +from os.path import join + +import numpy as np +from PIL import Image + + +def extract_frames(video_path): + # Get the base path and video name + base_path, video_name = os.path.split(video_path) + # Remove the extension from the video name to get the folder name + folder_name = 'imgs' + # Create the new folder path + folder_path = os.path.join(base_path, folder_name) + + if not os.path.exists(folder_path): + # Create the folder if it doesn't exist; + os.makedirs(folder_path) + + # Create the output file pattern + output_pattern = os.path.join(folder_path, '%06d.png') + + # Call ffmpeg to extract the frames + subprocess.call([ + 'ffmpeg', '-i', video_path, '-q:v', '0', '-start_number', '1', + output_pattern + ]) + else: + # Skip this video if the frame folder already exist! + print(f'{folder_path} already exist. 
Skip {video_path}!') + return + + +class Base300VW: + + def __init__(self): + extra_path = './tests/data/300vw/broken_frames.npy' + self.broken_frames = np.load(extra_path, allow_pickle=True).item() + self.videos_full = [ + '001', '002', '003', '004', '007', '009', '010', '011', '013', + '015', '016', '017', '018', '019', '020', '022', '025', '027', + '028', '029', '031', '033', '034', '035', '037', '039', '041', + '043', '044', '046', '047', '048', '049', '053', '057', '059', + '112', '113', '114', '115', '119', '120', '123', '124', '125', + '126', '138', '143', '144', '150', '158', '160', '203', '204', + '205', '208', '211', '212', '213', '214', '218', '223', '224', + '225', '401', '402', '403', '404', '405', '406', '407', '408', + '409', '410', '411', '412', '505', '506', '507', '508', '509', + '510', '511', '514', '515', '516', '517', '518', '519', '520', + '521', '522', '524', '525', '526', '528', '529', '530', '531', + '533', '537', '538', '540', '541', '546', '547', '548', '550', + '551', '553', '557', '558', '559', '562' + ] + + # Category 1 in laboratory and naturalistic well-lit conditions + self.videos_test_1 = [ + '114', '124', '125', '126', '150', '158', '401', '402', '505', + '506', '507', '508', '509', '510', '511', '514', '515', '518', + '519', '520', '521', '522', '524', '525', '537', '538', '540', + '541', '546', '547', '548' + ] + # Category 2 in real-world human-computer interaction applications + self.videos_test_2 = [ + '203', '208', '211', '212', '213', '214', '218', '224', '403', + '404', '405', '406', '407', '408', '409', '412', '550', '551', + '553' + ] + # Category 3 in arbitrary conditions + self.videos_test_3 = [ + '410', '411', '516', '517', '526', '528', '529', '530', '531', + '533', '557', '558', '559', '562' + ] + + self.videos_test = \ + self.videos_test_1 + self.videos_test_2 + self.videos_test_3 + self.videos_train = [ + i for i in self.videos_full if i not in self.videos_test + ] + + self.videos_part = ['001', '401'] + + self.point_num = 68 + + +class Preprocess300VW(Base300VW): + + def __init__(self, dataset_root): + super().__init__() + self.dataset_root = dataset_root + self._extract_frames() + self.json_data = self._init_json_data() + + def _init_json_data(self): + """Initialize JSON data structure.""" + return { + 'images': [], + 'annotations': [], + 'categories': [{ + 'id': 1, + 'name': 'person' + }] + } + + def _extract_frames(self): + """Extract frames from videos.""" + all_video_paths = glob(os.path.join(self.dataset_root, '*/vid.avi')) + with multiprocessing.Pool() as pool: + pool.map(extract_frames, all_video_paths) + + def _extract_keypoints_from_pts(self, file_path): + """Extract keypoints from .pts files.""" + keypoints = [] + with open(file_path, 'r') as file: + file_content = file.read() + start_index = file_content.find('{') + end_index = file_content.rfind('}') + if start_index != -1 and end_index != -1: + data_inside_braces = file_content[start_index + 1:end_index] + lines = data_inside_braces.split('\n') + for line in lines: + if line.strip(): + x, y = map(float, line.split()) + keypoints.extend([x, y]) + else: + print('No data found inside braces.') + return keypoints + + def _get_video_list(self, video_list): + """Get video list based on input type.""" + if isinstance(video_list, list): + return self.videos_part + elif isinstance(video_list, str): + if hasattr(self, video_list): + return getattr(self, video_list) + else: + raise KeyError + elif video_list is None: + return self.videos_part + else: + raise ValueError + + def 
_process_image(self, img_path): + """Process image and return image dictionary.""" + image_dict = {} + image_dict['file_name'] = os.path.relpath(img_path, self.dataset_root) + image_pic = Image.open(img_path) + pic_width, pic_height = image_pic.size + image_dict['height'] = pic_height + image_dict['width'] = pic_width + image_pic.close() + return image_dict + + def _process_annotation(self, annot_path, image_id, anno_id): + """Process annotation and return annotation dictionary.""" + annotation = { + 'segmentation': [], + 'num_keypoints': self.point_num, + 'iscrowd': 0, + 'category_id': 1, + } + keypoints = self._extract_keypoints_from_pts(annot_path) + keypoints3 = [] + for kp_i in range(1, 68 * 2 + 1): + keypoints3.append(keypoints[kp_i - 1]) + if kp_i % 2 == 0: + keypoints3.append(1) + annotation['keypoints'] = keypoints3 + annotation = self._calculate_annotation_properties( + annotation, keypoints) + annotation['image_id'] = image_id + annotation['id'] = anno_id + return annotation + + def _calculate_annotation_properties(self, annotation, keypoints): + """Calculate properties for annotation.""" + keypoints_x = [] + keypoints_y = [] + for j in range(self.point_num * 2): + if j % 2 == 0: + keypoints_x.append(keypoints[j]) + else: + keypoints_y.append(keypoints[j]) + x_left = min(keypoints_x) + x_right = max(keypoints_x) + y_low = min(keypoints_y) + y_high = max(keypoints_y) + w = x_right - x_left + h = y_high - y_low + annotation['scale'] = math.ceil(max(w, h)) / 200 + annotation['area'] = w * h + annotation['center'] = [(x_left + x_right) / 2, (y_low + y_high) / 2] + return annotation + + def convert_annotations(self, + video_list=None, + json_save_name='anno_300vw.json'): + """Convert 300vw original annotations to coco format.""" + video_list = self._get_video_list(video_list) + image_id = 0 + anno_id = 0 + for video_id in video_list: + annot_root = join(self.dataset_root, video_id, 'annot') + img_dir = join(self.dataset_root, video_id, 'imgs') + if not (os.path.isdir(annot_root) and os.path.isdir(img_dir)): + print(f'{annot_root} or {img_dir} not found. skip {video_id}!') + continue + annots = sorted(os.listdir(annot_root)) + for annot in annots: + frame_num = int(annot.split('.')[0]) + if video_id in self.broken_frames and \ + frame_num in self.broken_frames[video_id]: + print(f'skip broken frames: {frame_num} in {video_id}') + continue + + img_path = os.path.join(img_dir, f'{frame_num:06d}.png') + if not os.path.exists(img_path): + print(f'{img_path} not found. skip!') + continue + + # Process image and add to JSON data + image_dict = self._process_image(img_path) + image_dict['id'] = image_id + + # Construct annotation path + annot_path = os.path.join(annot_root, annot) + annotation = self._process_annotation(annot_path, image_id, + anno_id) + + # Process annotation and add to JSON data + self.json_data['images'].append(image_dict) + self.json_data['annotations'].append(annotation) + + image_id += 1 + anno_id += 1 + + print(f'Annotations from "{annot_root}" have been converted.') + + self._save_json_data(json_save_name) + + def _save_json_data(self, json_save_name): + json_save_path = os.path.join(self.dataset_root, json_save_name) + with open(json_save_path, 'w') as json_file: + json.dump(self.json_data, json_file, indent=4) + + +if __name__ == '__main__': + convert300vw = Preprocess300VW(dataset_root='./tests/data/300vw') + convert300vw.convert_annotations()
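The converter above is normally run as a script, but because the module name `300vw2coco` begins with a digit it cannot be imported with a plain `import` statement. Below is a sketch of driving it from Python via `importlib`; the `./data/300vw` path is an assumed unzip location, not something this PR prescribes.

```python
import importlib.util

# Load tools/dataset_converters/300vw2coco.py by file path, since the
# module name starts with a digit and cannot be imported directly.
spec = importlib.util.spec_from_file_location(
    'vw2coco', 'tools/dataset_converters/300vw2coco.py')
vw2coco = importlib.util.module_from_spec(spec)
spec.loader.exec_module(vw2coco)

# Note: Preprocess300VW.__init__ runs ffmpeg frame extraction over
# dataset_root and reads ./tests/data/300vw/broken_frames.npy relative
# to the current working directory, so run this from the repo root.
# './data/300vw' is an assumed location of the unzipped dataset.
converter = vw2coco.Preprocess300VW(dataset_root='./data/300vw')

# 'videos_train' is the attribute name of the training split defined in
# Base300VW; 'videos_test_1' etc. would select the test categories.
converter.convert_annotations(
    video_list='videos_train', json_save_name='train.json')
```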
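The `scale` written by `_calculate_annotation_properties` (`math.ceil(max(w, h)) / 200`) and the `pixel_std = 200.` in `Face300VWDataset.parse_data_info` are two halves of one convention. A worked check using the first annotation in `tests/data/300vw/anno_300vw.json`, with the center/half-extent arithmetic behind `bbox_cs2xyxy` written out explicitly:

```python
import numpy as np

# center and scale copied from image_id 0 in anno_300vw.json.
pixel_std = 200.0
center = np.array([[810.8955, 227.172]], dtype=np.float32)
scale = np.array([[0.695, 0.695]], dtype=np.float32) * pixel_std  # 139 px

# Equivalent of bbox_cs2xyxy(center, scale): the box extends half the
# de-normalized scale on each side of the center.
x1y1 = center - scale / 2
x2y2 = center + scale / 2
bbox = np.concatenate([x1y1, x2y2], axis=1)
print(bbox)  # approx. [[741.40 157.67 880.40 296.67]]
```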
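A minimal usage sketch for the new `FastVisualizer.draw_points`. It relies only on behavior visible in the diff, namely the `metainfo.get(...)` fallbacks in `__init__` and the white-color padding inside `draw_points`; the image and point values are made up for illustration.

```python
import numpy as np

from mmpose.visualization.fast_visualizer import FastVisualizer

# An empty metainfo is enough for draw_points: keypoint_colors falls
# back to a single white entry and is padded to the number of points.
vis = FastVisualizer(metainfo={}, radius=4)

img = np.zeros((480, 640, 3), dtype=np.uint8)
points = np.array([[100, 100], [200, 150], [300, 240]], dtype=np.float32)

# A bare (Point_num, 2) array is promoted to shape (1, Point_num, 2);
# the image is modified in place.
vis.draw_points(img, points)
```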