merge BatchShapeDataPreprocessor into PoseDataPreprocessor
Ben-Louis committed Sep 14, 2023
1 parent 502b1cd commit 9ecf22f
Showing 6 changed files with 127 additions and 139 deletions.
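The functional change: `BatchShapeDataPreprocessor` no longer exists as a separate class; its batch-shape metainfo handling and pillow-backend normalization are folded into `PoseDataPreprocessor`. Downstream configs only need their `type` switched, roughly as in this sketch (keys taken from the EDPose config changed below; everything else stays as-is):

```python
# Before this commit (class now removed):
# data_preprocessor = dict(type='BatchShapeDataPreprocessor', ...)

# After this commit, PoseDataPreprocessor covers both use cases:
data_preprocessor = dict(
    type='PoseDataPreprocessor',
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    bgr_to_rgb=True)
```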
2 changes: 1 addition & 1 deletion configs/body_2d_keypoint/edpose/coco/edpose_coco.md
@@ -55,5 +55,5 @@ Results on COCO val2017

| Arch | BackBone | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
| :-------------------------------------------- | :-------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :--------------------------------------------: | :-------------------------------------------: |
- | [edpose_res50_coco](/configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py) | ResNet-50 | 0.716 | 0.897 | 0.783 | 0.793 | 0.943 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.json) |
+ | [edpose_res50_coco](/configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py) | ResNet-50 | 0.716 | 0.898 | 0.783 | 0.793 | 0.944 | [ckpt](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth) | [log](https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.json) |
| | | | | | | | | |
4 changes: 2 additions & 2 deletions configs/body_2d_keypoint/edpose/coco/edpose_coco.yml
@@ -17,9 +17,9 @@ Models:
- Dataset: COCO
Metrics:
AP: 0.716
-      AP@0.5: 0.897
+      AP@0.5: 0.898
AP@0.75: 0.783
AR: 0.793
-      AR@0.5: 0.943
+      AR@0.5: 0.944
Task: Body 2D Keypoint
Weights: https://download.openmmlab.com/mmpose/v1/body_2d_keypoint/edpose/coco/edpose_res50_coco_3rdparty.pth
2 changes: 1 addition & 1 deletion configs/body_2d_keypoint/edpose/coco/edpose_res50_coco.py
@@ -36,7 +36,7 @@
model = dict(
type='BottomupPoseEstimator',
data_preprocessor=dict(
-        type='BatchShapeDataPreprocessor',
+        type='PoseDataPreprocessor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
bgr_to_rgb=True,
18 changes: 12 additions & 6 deletions mmpose/codecs/edpose_label.py
@@ -34,6 +34,12 @@ class EDPoseLabel(BaseKeypointCodec):
"""

auxiliary_encode_keys = {'area', 'bboxes', 'img_shape'}
+    instance_mapping_table = dict(
+        bbox='bboxes',
+        keypoints='keypoints',
+        keypoints_visible='keypoints_visible',
+        area='areas',
+    )

def __init__(self, num_select: int = 100, num_keypoints: int = 17):
super().__init__()
@@ -81,18 +87,18 @@ def encode(

if bboxes is not None:
bboxes = np.concatenate(bbox_xyxy2cs(bboxes), axis=-1)
-            bboxes_labels = bboxes / np.array([w, h, w, h], dtype=np.float32)
+            bboxes = bboxes / np.array([w, h, w, h], dtype=np.float32)

if area is not None:
-            area_labels = area / float(w * h)
+            area = area / float(w * h)

if keypoints is not None:
-            keypoint_labels = keypoints / np.array([w, h], dtype=np.float32)
+            keypoints = keypoints / np.array([w, h], dtype=np.float32)

encoded = dict(
-            keypoint_labels=keypoint_labels,
-            area_labels=area_labels,
-            bboxes_labels=bboxes_labels,
+            keypoints=keypoints,
+            area=area,
+            bbox=bboxes,
keypoints_visible=keypoints_visible)

return encoded
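The net effect of the codec change: `encode` now normalizes the arrays in place and returns keys that match the new `instance_mapping_table`, instead of the old `*_labels` names. A minimal numpy sketch of the normalization it performs (the center/scale conversion mirrors mmpose's `bbox_xyxy2cs` with unit padding; the values are illustrative):

```python
import numpy as np

w, h = 640, 480                                # image width and height
bboxes = np.array([[100., 120., 300., 360.]])  # one box, xyxy format

# xyxy -> (center, scale), then concatenate, as bbox_xyxy2cs does
center = (bboxes[:, :2] + bboxes[:, 2:]) / 2   # (200, 240)
scale = bboxes[:, 2:] - bboxes[:, :2]          # (200, 240)
bboxes_cs = np.concatenate([center, scale], axis=-1)

# normalize everything to the [0, 1] range of the image, as in encode()
bboxes_norm = bboxes_cs / np.array([w, h, w, h], dtype=np.float32)
keypoints = np.array([[[160., 240.]]])         # one keypoint
keypoints_norm = keypoints / np.array([w, h], dtype=np.float32)
area_norm = np.array([4000.]) / float(w * h)

encoded = dict(
    keypoints=keypoints_norm,  # was `keypoint_labels`
    area=area_norm,            # was `area_labels`
    bbox=bboxes_norm)          # was `bboxes_labels`
```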
6 changes: 3 additions & 3 deletions mmpose/models/data_preprocessors/__init__.py
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .batch_augmentation import BatchSyncRandomResize
-from .data_preprocessor import BatchShapeDataPreprocessor, PoseDataPreprocessor
+from .data_preprocessor import PoseDataPreprocessor

__all__ = [
-    'PoseDataPreprocessor', 'BatchSyncRandomResize',
-    'BatchShapeDataPreprocessor'
+    'PoseDataPreprocessor',
+    'BatchSyncRandomResize',
]
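After this change the subpackage exports only two names, so importing the old class raises an `ImportError`. A quick smoke test of the new public surface (assuming an environment with this commit installed):

```python
from mmpose.models.data_preprocessors import (BatchSyncRandomResize,
                                              PoseDataPreprocessor)

# Gone from the public API after this commit:
# from mmpose.models.data_preprocessors import BatchShapeDataPreprocessor
# -> ImportError
```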
234 changes: 108 additions & 126 deletions mmpose/models/data_preprocessors/data_preprocessor.py
@@ -1,13 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
-from numbers import Number
from typing import List, Optional, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torchvision
+import torchvision.transforms.functional as tvF
from mmengine.model import ImgDataPreprocessor
from mmengine.model.utils import stack_batch
from mmengine.utils import is_seq_of
@@ -18,7 +17,48 @@

@MODELS.register_module()
class PoseDataPreprocessor(ImgDataPreprocessor):
"""Image pre-processor for pose estimation tasks."""
"""Image pre-processor for pose estimation tasks.
Comparing with the :class:`ImgDataPreprocessor`,
1. It will additionally append batch_input_shape
to data_samples considering the DETR-based pose estimation tasks.
2. Add a 'pillow backend' pipeline based normalize operation, convert
np.array to PIL.Image, and normalize it through torchvision.
3. Support image augmentation transforms on batched data.
It provides the data pre-processing as follows
- Collate and move data to the target device.
- Pad inputs to the maximum size of current batch with defined
``pad_value``. The padding size can be divisible by a defined
``pad_size_divisor``
- Stack inputs to batch_inputs.
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
- Normalize image with defined std and mean.
- Apply batch augmentation transforms.
Args:
mean (sequence of float, optional): The pixel mean of R, G, B
channels. Defaults to None.
std (sequence of float, optional): The pixel standard deviation
of R, G, B channels. Defaults to None.
pad_size_divisor (int): The size of padded image should be
divisible by ``pad_size_divisor``. Defaults to 1.
pad_value (float or int): The padded pixel value. Defaults to 0.
bgr_to_rgb (bool): whether to convert image from BGR to RGB.
Defaults to False.
rgb_to_bgr (bool): whether to convert image from RGB to BGR.
Defaults to False.
non_blocking (bool): Whether block current process
when transferring data to device. Defaults to False.
batch_augments: (list of dict, optional): Configs of augmentation
transforms on batched data. Defaults to None.
normalize_bakend (str): choose the normalize backend
in ['cv2', 'pillow']
"""

def __init__(self,
mean: Sequence[float] = None,
@@ -28,7 +68,8 @@ def __init__(self,
bgr_to_rgb: bool = False,
rgb_to_bgr: bool = False,
non_blocking: Optional[bool] = False,
-                 batch_augments: Optional[List[dict]] = None):
+                 batch_augments: Optional[List[dict]] = None,
+                 normalize_bakend: str = 'cv2'):
super().__init__(
mean=mean,
std=std,
@@ -37,6 +78,13 @@ def __init__(self,
bgr_to_rgb=bgr_to_rgb,
rgb_to_bgr=rgb_to_bgr,
non_blocking=non_blocking)

+        assert normalize_bakend in ('cv2', 'pillow'), f'the argument ' \
+            f'`normalize_bakend` must be either \'cv2\' or \'pillow\', ' \
+            f'but got \'{normalize_bakend}\'.'
+
+        self.normalize_bakend = normalize_bakend

if batch_augments is not None:
self.batch_augments = nn.ModuleList(
[MODELS.build(aug) for aug in batch_augments])
@@ -55,15 +103,33 @@ def forward(self, data: dict, training: bool = False) -> dict:
dict: Data in the same format as the model input.
"""
batch_pad_shape = self._get_pad_shape(data)
-        data = super().forward(data=data, training=training)
+        data = self.preprocess(data=data, training=training)
inputs, data_samples = data['inputs'], data['data_samples']

# update metainfo since the image shape might change
batch_input_shape = tuple(inputs[0].size()[-2:])
for data_sample, pad_shape in zip(data_samples, batch_pad_shape):
-            data_sample.set_metainfo({
+
+            aux_metainfo = {
'batch_input_shape': batch_input_shape,
'pad_shape': pad_shape
-            })
+            }
+
+            if 'input_size' not in data_sample.metainfo:
+                aux_metainfo['input_size'] = data_sample.img_shape
+
+            if 'input_center' not in data_sample.metainfo \
+                    or 'input_scale' not in data_sample.metainfo:
+
+                w, h = data_sample.ori_shape
+                center = np.array([w / 2, h / 2], dtype=np.float32)
+                scale = np.array([w, h], dtype=np.float32)
+                aux_metainfo['input_center'] = center
+                aux_metainfo['input_scale'] = scale
+
+            data_sample.set_metainfo(aux_metainfo)
+
+        # apply batch augmentations
if training and self.batch_augments is not None:
for batch_aug in self.batch_augments:
inputs, data_samples = batch_aug(inputs, data_samples)
@@ -104,105 +170,9 @@ def _get_pad_shape(self, data: dict) -> List[tuple]:
f'{type(data)}: {data}')
return batch_pad_shape


-@MODELS.register_module()
-class BatchShapeDataPreprocessor(ImgDataPreprocessor):
-    """Image pre-processor for pose estimation tasks.
-
-    Comparing with the :class:`PoseDataPreprocessor`,
-
-    1. It will additionally append batch_input_shape
-    to data_samples considering the DETR-based pose estimation tasks.
-
-    2. Add a 'pillow backend' pipeline based normalize operation, convert
-    np.array to PIL.Image, and normalize it through torchvision.
-
-    It provides the data pre-processing as follows
-
-    - Collate and move data to the target device.
-    - Pad inputs to the maximum size of current batch with defined
-      ``pad_value``. The padding size can be divisible by a defined
-      ``pad_size_divisor``
-    - Stack inputs to batch_inputs.
-    - Convert inputs from bgr to rgb if the shape of input is (3, H, W).
-    - Normalize image with defined std and mean.
-
-    Args:
-    - mean (Sequence[Number], optional): The pixel mean of R, G, B
-        channels. Defaults to None.
-    - std (Sequence[Number], optional): The pixel standard deviation
-        of R, G, B channels. Defaults to None.
-    - pad_size_divisor (int): The size of padded image should be
-        divisible by ``pad_size_divisor``. Defaults to 1.
-    - pad_value (Number): The padded pixel value. Defaults to 0.
-    - bgr_to_rgb (bool): whether to convert image from BGR to RGB.
-        Defaults to False.
-    - rgb_to_bgr (bool): whether to convert image from RGB to RGB.
-        Defaults to False.
-    - non_blocking (bool): Whether block current process
-        when transferring data to device. Defaults to False.
-    - normalize_bakend (str): choose the normalize backend
-        in ['cv2', 'pillow']
-    """
-
-    def __init__(self,
-                 mean: Sequence[Number] = None,
-                 std: Sequence[Number] = None,
-                 pad_size_divisor: int = 1,
-                 pad_value: Union[float, int] = 0,
-                 bgr_to_rgb: bool = False,
-                 rgb_to_bgr: bool = False,
-                 non_blocking: Optional[bool] = False,
-                 normalize_bakend: str = 'cv2'):
-        super().__init__(
-            mean=mean,
-            std=std,
-            pad_size_divisor=pad_size_divisor,
-            pad_value=pad_value,
-            bgr_to_rgb=bgr_to_rgb,
-            rgb_to_bgr=rgb_to_bgr,
-            non_blocking=non_blocking)
-        self.normalize_bakend = normalize_bakend
-
-    def forward(self, data: dict, training: bool = False) -> dict:
-        """Perform normalization, padding and bgr2rgb conversion based on
-        ``BaseDataPreprocessor``.
-
-        Args:
-            data (dict): Data sampled from dataloader.
-            training (bool): Whether to enable training time augmentation.
-
-        Returns:
-            dict: Data in the same format as the model input.
-        """
-        if self.normalize_bakend == 'cv2':
-            data = super().forward(data=data, training=training)
-        else:
-            data = self.normalize_pillow(data=data, training=training)
-
-        inputs, data_samples = data['inputs'], data['data_samples']
-
-        if data_samples is not None:
-            # NOTE the batched image size information may be useful, e.g.
-            # in DETR, this is needed for the construction of masks, which is
-            # then used for the transformer_head.
-            batch_input_shape = tuple(inputs[0].size()[-2:])
-            for data_sample in data_samples:
-
-                w, h = data_sample.ori_shape
-                center = np.array([w / 2, h / 2], dtype=np.float32)
-                scale = np.array([w, h], dtype=np.float32)
-                data_sample.set_metainfo({
-                    'batch_input_shape': batch_input_shape,
-                    'input_size': data_sample.img_shape,
-                    'input_center': center,
-                    'input_scale': scale
-                })
-        return {'inputs': inputs, 'data_samples': data_samples}
-
-    def normalize_pillow(self,
-                         data: dict,
-                         training: bool = False) -> Union[dict, list]:
+    def preprocess(self,
+                   data: dict,
+                   training: bool = False) -> Union[dict, list]:

data = self.cast_data(data) # type: ignore
_batch_inputs = data['inputs']
@@ -214,13 +184,16 @@ def normalize_pillow(self,
if self._channel_conversion:
_batch_input = _batch_input[[2, 1, 0], ...]

-                _batch_input_array = _batch_input.detach().cpu().numpy(
-                ).transpose(1, 2, 0)
-                assert _batch_input_array.dtype == np.uint8, \
-                    'Pillow backend only support uint8 type'
-                pil_image = Image.fromarray(_batch_input_array)
-                _batch_input = torchvision.transforms.functional.to_tensor(
-                    pil_image).to(_batch_input.device)
+                if self.normalize_bakend == 'cv2':
+                    _batch_input = _batch_input.float()
+                elif self.normalize_bakend == 'pillow':
+                    _batch_input_array = _batch_input.detach().cpu().numpy(
+                    ).transpose(1, 2, 0)
+                    assert _batch_input_array.dtype == np.uint8, \
+                        'Pillow backend only support uint8 type'
+                    pil_image = Image.fromarray(_batch_input_array)
+                    _batch_input = tvF.to_tensor(pil_image).to(
+                        _batch_input.device)

# Normalization.
if self._enable_normalize:
@@ -230,8 +203,11 @@ def normalize_pillow(self,
'If the mean has 3 values, the input tensor '
'should in shape of (3, H, W), but got the tensor '
f'with shape {_batch_input.shape}')
-                _batch_input = torchvision.transforms.functional.normalize(
-                    _batch_input, mean=self.mean, std=self.std)
+                if self.normalize_bakend == 'cv2':
+                    _batch_input = (_batch_input - self.mean) / self.std
+                elif self.normalize_bakend == 'pillow':
+                    _batch_input = tvF.normalize(
+                        _batch_input, mean=self.mean, std=self.std)
batch_inputs.append(_batch_input)
# Pad and stack Tensor.
batch_inputs = stack_batch(batch_inputs, self.pad_size_divisor,
@@ -246,19 +222,25 @@ def normalize_pillow(self,
_batch_inputs = _batch_inputs[:, [2, 1, 0], ...]
# Convert to float after channel conversion to ensure
# efficiency
-            _batch_inputs_array = _batch_inputs.detach().cpu().numpy(
-            ).transpose(0, 2, 3, 1)
-            assert _batch_inputs.dtype == np.uint8, \
-                'Pillow backend only support uint8 type'
-            pil_image = Image.fromarray(_batch_inputs_array)
-            _batch_inputs = torchvision.transforms.functional.to_tensor(
-                pil_image).to(_batch_inputs.device)
+            if self.normalize_bakend == 'cv2':
+                _batch_inputs = _batch_inputs.float()
+            elif self.normalize_bakend == 'pillow':
+                _batch_inputs_array = _batch_inputs.detach().cpu().numpy(
+                ).transpose(0, 2, 3, 1)
+                assert _batch_inputs.dtype == np.uint8, \
+                    'Pillow backend only support uint8 type'
+                pil_image = Image.fromarray(_batch_inputs_array)
+                _batch_inputs = tvF.to_tensor(pil_image).to(
+                    _batch_inputs.device)

if self._enable_normalize:
-            _batch_inputs = torchvision.transforms.functional.normalize(
-                _batch_inputs,
-                mean=(self.mean / 255).tolist(),
-                std=(self.std / 255).tolist())
+            if self.normalize_bakend == 'cv2':
+                _batch_inputs = (_batch_inputs - self.mean) / self.std
+            elif self.normalize_bakend == 'pillow':
+                _batch_inputs = tvF.normalize(
+                    _batch_inputs,
+                    mean=(self.mean / 255).tolist(),
+                    std=(self.std / 255).tolist())

h, w = _batch_inputs.shape[2:]
target_h = math.ceil(
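Putting the merged class together, here is a usage sketch of the post-commit behavior. The values are hypothetical; the `inputs`/`data_samples` dict layout follows the usual mmengine `ImgDataPreprocessor` contract, and `pad_size_divisor=32` is chosen only to show the padding and `batch_input_shape` bookkeeping:

```python
import torch
from mmpose.models.data_preprocessors import PoseDataPreprocessor
from mmpose.structures import PoseDataSample

preprocessor = PoseDataPreprocessor(
    mean=[123.675, 116.28, 103.53],  # 0-255 range stats suit the 'cv2' backend
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_size_divisor=32,             # 250 -> ceil(250 / 32) * 32 = 256
    normalize_bakend='cv2')          # or 'pillow' (requires uint8 input);
                                     # the 'bakend' spelling matches the code

sample = PoseDataSample()
sample.set_metainfo(dict(img_shape=(250, 250), ori_shape=(250, 250)))

data = dict(
    inputs=[torch.randint(0, 256, (3, 250, 250), dtype=torch.uint8)],
    data_samples=[sample])

out = preprocessor(data, training=False)
print(out['inputs'].shape)                       # torch.Size([1, 3, 256, 256])
print(out['data_samples'][0].batch_input_shape)  # (256, 256)
print(out['data_samples'][0].input_center)       # [125. 125.], filled by forward()
```

Note that `forward()` only fills `input_size`, `input_center`, and `input_scale` when they are absent from the sample's metainfo, so top-down pipelines that already set them are unaffected.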
