Skip to content

Commit

Permalink
[Docs] Add advanced tutorial of implement new model. (#2539)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tau-J committed Jul 13, 2023
1 parent 45835ac commit d458942
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 43 deletions.
80 changes: 79 additions & 1 deletion docs/en/advanced_guides/implement_new_models.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,81 @@
# Implement New Models

Coming soon.
This tutorial will introduce how to implement your own models in MMPose. After summarizing, we split the need to implement new models into two categories:

1. Based on the algorithm paradigm supported by MMPose, customize the modules (backbone, neck, head, codec, etc.) in the model
2. Implement new algorithm paradigm

## Basic Concepts

What you want to implement is one of the above, and this section is important to you because it is the basic principle of building models in the OpenMMLab.

In MMPose, all the code related to the implementation of the model structure is stored in the [models directory](https://github.com/open-mmlab/mmpose/tree/main/mmpose/models) :

```shell
mmpose
|----models
|----backbones #
|----data_preprocessors # image normalization
|----heads #
|----losses # loss functions
|----necks #
|----pose_estimators # algorithm paradigm
|----utils #
```

You can refer to the following flow chart to locate the module you need to implement:

![image](https://github.com/open-mmlab/mmpose/assets/13503330/f4eeb99c-e2a1-4907-9d46-f110c51f0814)

## Pose Estimatiors

In pose estimatiors, we will define the inference process of a model, and decode the model output results in `predict()`, first transform it from `output space` to `input image space` using the [codec](./codecs.md), and then combine the meta information to transform to `original image space`.

![pose_estimator_en](https://github.com/open-mmlab/mmpose/assets/13503330/48c3813e-b977-4215-b5bc-e7379cfd2bce)

Currently, MMPose supports the following types of pose estimator:

1. [Top-down](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/pose_estimators/topdown.py): The input of the pose model is a cropped single target (animal, human body, human face, human hand, plant, clothes, etc.) image, and the output is the key point prediction result of the target
2. [Bottom-up](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/pose_estimators/bottomup.py): The input of the pose model is an image containing any number of targets, and the output is the key point prediction result of all targets in the image
3. [Pose Lifting](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/pose_estimators/pose_lifter.py): The input of the pose model is a 2D keypoint coordinate array, and the output is a 3D keypoint coordinate array

If the model you want to implement does not belong to the above algorithm paradigm, then you need to inherit the [BasePoseEstimator](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/pose_estimators/base.py) class to define your own algorithm paradigm.

## Backbones

If you want to implement a new backbone network, you need to create a new file in the [backbones directory](https://github.com/open-mmlab/mmpose/tree/main/mmpose/models/backbones) to define it.

The new backbone network needs to inherit the [BaseBackbone](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/backbones/base_backbone.py) class, and there is no difference in other aspects from inheriting `nn.Module` to create.

After completing the implementation of the backbone network, you need to use `MODELS` to register it:

```Python3
from mmpose.registry import MODELS
from .base_backbone import BaseBackbone


@MODELS.register_module()
class YourNewBackbone(BaseBackbone):
```

Finally, please remember to import your new backbone network in `[__init__.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/backbones/__init__.py)` .

## Heads

The addition of a new prediction head is similar to the backbone network process. You need to create a new file in the [heads directory](https://github.com/open-mmlab/mmpose/tree/main/mmpose/models/heads) to define it, and then inherit [BaseHead](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/base_head.py) .

One thing to note is that in MMPose, the loss function is calculated in the Head. According to the different training and evaluation stages, `loss()` and `predict()` are executed respectively.

In `predict()`, the model will call the `decode()` method of the corresponding codec to transform the model output result from `output space` to `input image space`.

After completing the implementation of the prediction head, you need to use `MODELS` to register it:

```Python3
from mmpose.registry import MODELS
from ..base_head import BaseHead

@MODELS.register_module()
class YourNewHead(BaseHead):
```

Finally, please remember to import your new prediction head in `[__init__.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/__init__.py)` .
38 changes: 18 additions & 20 deletions docs/en/guide_to_framework.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ The organization of data in MMPose contains:

### Dataset Meta Information

The meta information of a pose dataset usually includes the definition of keypoints and skeleton, symmetrical characteristic, and keypoint properties (e.g. belonging to upper or lower body, weights and sigmas). These information is important in data preprocessing, model training and evaluation. In MMpose, the dataset meta information is stored in configs files under `$MMPOSE/configs/_base_/datasets/`.
The meta information of a pose dataset usually includes the definition of keypoints and skeleton, symmetrical characteristic, and keypoint properties (e.g. belonging to upper or lower body, weights and sigmas). These information is important in data preprocessing, model training and evaluation. In MMpose, the dataset meta information is stored in configs files under [$MMPOSE/configs/_base_/datasets](https://github.com/open-mmlab/mmpose/tree/main/configs/_base_/datasets).

To use a custom dataset in MMPose, you need to add a new config file of the dataset meta information. Take the MPII dataset (`$MMPOSE/configs/_base_/datasets/mpii.py`) as an example. Here is its dataset information:
To use a custom dataset in MMPose, you need to add a new config file of the dataset meta information. Take the MPII dataset ([$MMPOSE/configs/_base_/datasets/mpii.py](https://github.com/open-mmlab/mmpose/blob/main/configs/_base_/datasets/mpii.py)) as an example. Here is its dataset information:

```Python
dataset_info = dict(
Expand Down Expand Up @@ -111,7 +111,7 @@ dataset_info = dict(
])
```

In the model config, the user needs to specify the metainfo path of the custom dataset (e.g. `$MMPOSE/configs/_base_/datasets/custom.py`) as follows:\`\`\`
In the model config, the user needs to specify the metainfo path of the custom dataset (e.g. `$MMPOSE/configs/_base_/datasets/custom.py`) as follows:

```python
# dataset and dataloader settings
Expand Down Expand Up @@ -148,17 +148,15 @@ test_dataloader = val_dataloader

To use custom dataset in MMPose, we recommend converting the annotations into a supported format (e.g. COCO or MPII) and directly using our implementation of the corresponding dataset. If this is not applicable, you may need to implement your own dataset class.

Most 2D keypoint datasets in MMPose **organize the annotations in a COCO-like style**. Thus we provide a base class [BaseCocoStyleDataset](mmpose/datasets/datasets/base/base_coco_style_dataset.py) for these datasets. We recommend that users subclass `BaseCocoStyleDataset` and override the methods as needed (usually `__init__()` and `_load_annotations()`) to extend to a new custom 2D keypoint dataset.
Most 2D keypoint datasets in MMPose **organize the annotations in a COCO-like style**. Thus we provide a base class [BaseCocoStyleDataset](https://github.com/open-mmlab/mmpose/blob/main/mmpose/datasets/datasets/base/base_coco_style_dataset.py) for these datasets. We recommend that users subclass [BaseCocoStyleDataset](https://github.com/open-mmlab/mmpose/blob/main/mmpose/datasets/datasets/base/base_coco_style_dataset.py) and override the methods as needed (usually `__init__()` and `_load_annotations()`) to extend to a new custom 2D keypoint dataset.

```{note}
Please refer to [COCO](./dataset_zoo/2d_body_keypoint.md) for more details about the COCO data format.
```

```{note}
The bbox format in MMPose is in `xyxy` instead of `xywh`, which is consistent with the format used in other OpenMMLab projects like [MMDetection](https://github.com/open-mmlab/mmdetection). We provide useful utils for bbox format conversion, such as `bbox_xyxy2xywh`, `bbox_xywh2xyxy`, `bbox_xyxy2cs`, etc., which are defined in `$MMPOSE/mmpose/structures/bbox/transforms.py`.
```
The bbox format in MMPose is in `xyxy` instead of `xywh`, which is consistent with the format used in other OpenMMLab projects like [MMDetection](https://github.com/open-mmlab/mmdetection). We provide useful utils for bbox format conversion, such as `bbox_xyxy2xywh`, `bbox_xywh2xyxy`, `bbox_xyxy2cs`, etc., which are defined in [$MMPOSE/mmpose/structures/bbox/transforms.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/structures/bbox/transforms.py).

Let's take the implementation of the MPII dataset (`$MMPOSE/mmpose/datasets/datasets/body/mpii_dataset.py`) as an example.
Let's take the implementation of the MPII dataset ([$MMPOSE/mmpose/datasets/datasets/body/mpii_dataset.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/datasets/datasets/body/mpii_dataset.py)) as an example.

```Python
@DATASETS.register_module()
Expand Down Expand Up @@ -264,7 +262,7 @@ class MpiiDataset(BaseCocoStyleDataset):

When supporting MPII dataset, since we need to use `head_size` to calculate `PCKh`, we add `headbox_file` to `__init__()` and override`_load_annotations()`.

To support a dataset that is beyond the scope of `BaseCocoStyleDataset`, you may need to subclass from the `BaseDataset` provided by [MMEngine](https://github.com/open-mmlab/mmengine). Please refer to the [documents](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/basedataset.html) for details.
To support a dataset that is beyond the scope of [BaseCocoStyleDataset](https://github.com/open-mmlab/mmpose/blob/main/mmpose/datasets/datasets/base/base_coco_style_dataset.py), you may need to subclass from the `BaseDataset` provided by [MMEngine](https://github.com/open-mmlab/mmengine). Please refer to the [documents](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/basedataset.html) for details.

### Pipeline

Expand Down Expand Up @@ -302,13 +300,13 @@ Here is a diagram to show the workflow of data transformation among the three sc

![migration-en](https://user-images.githubusercontent.com/13503330/187190213-cad87b5f-0a95-4f1f-b722-15896914ded4.png)

In MMPose, the modules used for data transformation are under `$MMPOSE/mmpose/datasets/transforms`, and their workflow is shown as follows:
In MMPose, the modules used for data transformation are under `[$MMPOSE/mmpose/datasets/transforms](https://github.com/open-mmlab/mmpose/tree/main/mmpose/datasets/transforms)`, and their workflow is shown as follows:

![transforms-en](https://user-images.githubusercontent.com/13503330/187190352-a7662346-b8da-4256-9192-c7a84b15cbb5.png)

#### i. Augmentation

Commonly used transforms are defined in `$MMPOSE/mmpose/datasets/transforms/common_transforms.py`, such as `RandomFlip`, `RandomHalfBody`, etc.
Commonly used transforms are defined in [$MMPOSE/mmpose/datasets/transforms/common_transforms.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/datasets/transforms/common_transforms.py), such as `RandomFlip`, `RandomHalfBody`, etc.

For top-down methods, `Shift`, `Rotate`and `Resize` are implemented by `RandomBBoxTransform`**.** For bottom-up methods, `BottomupRandomAffine` is used.

Expand Down Expand Up @@ -352,15 +350,15 @@ Note that we unify the data format of top-down and bottom-up methods, which mean

- Bottom-up: `[B, N, K, D]`

The provided codecs are stored under `$MMPOSE/mmpose/codecs`.
The provided codecs are stored under [$MMPOSE/mmpose/codecs](https://github.com/open-mmlab/mmpose/tree/main/mmpose/codecs).

```{note}
If you wish to customize a new codec, you can refer to [Codec](./user_guides/codecs.md) for more details.
```

#### iv. Packing

After the data is transformed, you need to pack it using `PackPoseInputs`.
After the data is transformed, you need to pack it using [PackPoseInputs](https://github.com/open-mmlab/mmpose/blob/main/mmpose/datasets/transforms/formatting.py).

This method converts the data stored in the dictionary `results` into standard data structures in MMPose, such as `InstanceData`, `PixelData`, `PoseDataSample`, etc.

Expand Down Expand Up @@ -425,7 +423,7 @@ In MMPose 1.0, the model consists of the following components:

- **Head**: used to implement the core algorithm and loss function

We define a base class `BasePoseEstimator` for the model in `$MMPOSE/models/pose_estimators/base.py`. All models, e.g. `TopdownPoseEstimator`, should inherit from this base class and override the corresponding methods.
We define a base class `BasePoseEstimator` for the model in [$MMPOSE/models/pose_estimators/base.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/pose_estimators/base.py). All models, e.g. `TopdownPoseEstimator`, should inherit from this base class and override the corresponding methods.

Three modes are provided in `forward()` of the estimator:

Expand Down Expand Up @@ -477,7 +475,7 @@ It will transpose the channel order of the input image from `bgr` to `rgb` and n

### Backbone

MMPose provides some commonly used backbones under `$MMPOSE/mmpose/models/backbones`.
MMPose provides some commonly used backbones under [$MMPOSE/mmpose/models/backbones](https://github.com/open-mmlab/mmpose/tree/main/mmpose/models/backbones).

In practice, developers often use pre-trained backbone weights for transfer learning, which can improve the performance of the model on small datasets.

Expand Down Expand Up @@ -515,7 +513,7 @@ It should be emphasized that if you add a new backbone, you need to register it
class YourBackbone(BaseBackbone):
```

Besides, import it in `$MMPOSE/mmpose/models/backbones/__init__.py`, and add it to `__all__`.
Besides, import it in [$MMPOSE/mmpose/models/backbones/\_\_init\_\_.py](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/backbones/__init__.py), and add it to `__all__`.

### Neck

Expand Down Expand Up @@ -559,21 +557,21 @@ Neck is usually a module between Backbone and Head, which is used in some algori

Generally speaking, Head is often the core of an algorithm, which is used to make predictions and perform loss calculation.

Modules related to Head in MMPose are defined under `$MMPOSE/mmpose/models/heads`, and developers need to inherit the base class `BaseHead` when customizing Head and override the following methods:
Modules related to Head in MMPose are defined under [$MMPOSE/mmpose/models/heads](https://github.com/open-mmlab/mmpose/tree/main/mmpose/models/heads), and developers need to inherit the base class `BaseHead` when customizing Head and override the following methods:

- forward()

- predict()

- loss()

Specifically, `predict()` method needs to return pose predictions in the image space, which is obtained from the model output though the decoding function provided by the codec. We implement this process in `BaseHead.decode()`.
Specifically, `predict()` method needs to return pose predictions in the image space, which is obtained from the model output though the decoding function provided by the codec. We implement this process in [BaseHead.decode()](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/base_head.py).

On the other hand, we will perform test-time augmentation(TTA) in `predict()`.

A commonly used TTA is `flip_test`, namely, an image and its flipped version are sent into the model to inference, and the output of the flipped version will be flipped back, then average them to stabilize the prediction.

Here is an example of `predict()` in `RegressionHead`:
Here is an example of `predict()` in [RegressionHead](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/regression_heads/regression_head.py):

```Python
def predict(self,
Expand Down Expand Up @@ -627,7 +625,7 @@ keypoint_weights = torch.cat([
])
```

Here is the complete implementation of `loss()` in `RegressionHead`:
Here is the complete implementation of `loss()` in [RegressionHead](https://github.com/open-mmlab/mmpose/blob/main/mmpose/models/heads/regression_heads/regression_head.py):

```Python
def loss(self,
Expand Down
Loading

0 comments on commit d458942

Please sign in to comment.