Skip to content

Commit

Permalink
[Feature] support RTMPose Gradio app in projects (#2877)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis committed Dec 20, 2023
1 parent 7d1b3e2 commit b52b302
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ Models:
Weights: https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-body7_pt-body7_420e-384x288-65e718c4_20230504.pth
- Config: configs/body_2d_keypoint/rtmpose/body8/rtmpose-l_8xb256-420e_body8-384x288.py
In Collection: RTMPose
Alias: rtmpose-l
Metadata:
Architecture: *id001
Training Data: *id002
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Models:
Metadata:
Architecture: *id001
Training Data: COCO-WholeBody
Name: rtmpose-l_8xb32-270e_coco-wholebody-384x288.py
Name: rtmpose-l_8xb32-270e_coco-wholebody-384x288
Results:
- Dataset: COCO-WholeBody
Metrics:
Expand Down
8 changes: 7 additions & 1 deletion mmpose/apis/inferencers/utils/default_det_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
mmpose_path = get_installed_path(MODULE2PACKAGE['mmpose'])

default_det_models = dict(
human=dict(model='rtmdet-m', weights=None, cat_ids=(0, )),
human=dict(
model=osp.join(
mmpose_path, '.mim', 'demo/mmdetection_cfg/'
'rtmdet_m_640-8xb32_coco-person.py'),
weights='https://download.openmmlab.com/mmpose/v1/projects/'
'rtmposev1/rtmdet_m_8xb32-100e_coco-obj365-person-235e8209.pth',
cat_ids=(0, )),
face=dict(
model=osp.join(mmpose_path, '.mim',
'demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py'),
Expand Down
208 changes: 208 additions & 0 deletions projects/rtmpose/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# Copyright (c) OpenMMLab. All rights reserved.

import os
from functools import partial

import gradio as gr

# prepare environment
project_path = os.path.join(os.path.dirname(os.path.abspath(__file__)))
mmpose_path = project_path.split('/projects', 1)[0]

os.system('python -m pip install Openmim')
os.system('python -m mim install "mmcv>=2.0.0"')
os.system('python -m mim install "mmengine>=0.9.0"')
os.system('python -m mim install "mmdet>=3.0.0"')
os.system(f'python -m mim install -e {mmpose_path}')
from mmpose.apis import MMPoseInferencer # noqa

models = [
'rtmpose | body', 'rtmo | body', 'rtmpose | face', 'dwpose | wholebody',
'rtmw | wholebody'
]
cached_model = {model: None for model in models}


def predict(input,
draw_heatmap=False,
model_type='body',
skeleton_style='mmpose',
input_type='image'):
"""Visualize the demo images.
Using mmdet to detect the human.
"""

if model_type == 'rtmpose | face':
if cached_model[model_type] is None:
cached_model[model_type] = MMPoseInferencer(pose2d='face')
model = cached_model[model_type]

elif model_type == 'dwpose | wholebody':
if cached_model[model_type] is None:
cached_model[model_type] = MMPoseInferencer(
pose2d=os.path.join(
project_path, 'rtmpose/wholebody_2d_keypoint/'
'rtmpose-l_8xb32-270e_coco-wholebody-384x288.py'),
pose2d_weights='https://download.openmmlab.com/mmpose/v1/'
'projects/rtmposev1/rtmpose-l_simcc-ucoco_dw-ucoco_270e-'
'384x288-2438fd99_20230728.pth')
model = cached_model[model_type]

elif model_type == 'rtmw | wholebody':
if cached_model[model_type] is None:
cached_model[model_type] = MMPoseInferencer(
pose2d=os.path.join(
project_path, 'rtmpose/wholebody_2d_keypoint/'
'rtmw-l_8xb320-270e_cocktail14-384x288.py'),
pose2d_weights='https://download.openmmlab.com/mmpose/v1/'
'projects/rtmw/rtmw-dw-x-l_simcc-cocktail14_270e-'
'384x288-20231122.pth')
model = cached_model[model_type]

elif model_type == 'rtmpose | body':
if cached_model[model_type] is None:
cached_model[model_type] = MMPoseInferencer(pose2d='rtmpose-l')
model = cached_model[model_type]

elif model_type == 'rtmo | body':
if cached_model[model_type] is None:
cached_model[model_type] = MMPoseInferencer(pose2d='rtmo')
model = cached_model[model_type]
draw_heatmap = False

else:
raise ValueError

if input_type == 'image':

result = next(
model(
input,
return_vis=True,
draw_heatmap=draw_heatmap,
skeleton_style=skeleton_style))
img = result['visualization'][0][..., ::-1]
return img

elif input_type == 'video':

for _ in model(
input,
vis_out_dir='test.mp4',
draw_heatmap=draw_heatmap,
skeleton_style=skeleton_style):
pass

return 'test.mp4'

return None


news_list = [
'2023-8-1: We support [DWPose](https://arxiv.org/pdf/2307.15880.pdf).',
'2023-9-25: We release an alpha version of RTMW model, the technical '
'report will be released soon.',
'2023-12-11: Update RTMW models, the online version is the RTMW-l with '
'70.1 mAP on COCO-Wholebody.',
'2023-12-13: We release an alpha version of RTMO (One-stage RTMPose) '
'models.',
]

with gr.Blocks() as demo:

with gr.Tab('Upload-Image'):
input_img = gr.Image(type='numpy')
button = gr.Button('Inference', variant='primary')
hm = gr.Checkbox(label='draw-heatmap', info='Whether to draw heatmap')
model_type = gr.Dropdown([
'rtmpose | body', 'rtmo | body', 'rtmpose | face',
'dwpose | wholebody', 'rtmw | wholebody'
],
label='Model | Keypoint Type',
info='Body / Face / Wholebody')

gr.Markdown('## News')
for news in news_list[::-1]:
gr.Markdown(news)

gr.Markdown('## Output')
out_image = gr.Image(type='numpy')
gr.Examples(['./tests/data/coco/000000000785.jpg'], input_img)
input_type = 'image'
button.click(
partial(predict, input_type=input_type),
[input_img, hm, model_type], out_image)

with gr.Tab('Webcam-Image'):
input_img = gr.Image(source='webcam', type='numpy')
button = gr.Button('Inference', variant='primary')
hm = gr.Checkbox(label='draw-heatmap', info='Whether to draw heatmap')
model_type = gr.Dropdown([
'rtmpose | body', 'rtmo | body', 'rtmpose | face',
'dwpose | wholebody', 'rtmw | wholebody'
],
label='Model | Keypoint Type',
info='Body / Face / Wholebody')

gr.Markdown('## News')
for news in news_list[::-1]:
gr.Markdown(news)

gr.Markdown('## Output')
out_image = gr.Image(type='numpy')

input_type = 'image'
button.click(
partial(predict, input_type=input_type),
[input_img, hm, model_type], out_image)

with gr.Tab('Upload-Video'):
input_video = gr.Video(type='mp4')
button = gr.Button('Inference', variant='primary')
hm = gr.Checkbox(label='draw-heatmap', info='Whether to draw heatmap')
model_type = gr.Dropdown([
'rtmpose | body', 'rtmo | body', 'rtmpose | face',
'dwpose | wholebody', 'rtmw | wholebody'
],
label='Model | Keypoint type',
info='Body / Face / Wholebody')

gr.Markdown('## News')
for news in news_list[::-1]:
gr.Markdown(news)

gr.Markdown('## Output')
out_video = gr.Video()

input_type = 'video'
button.click(
partial(predict, input_type=input_type),
[input_video, hm, model_type], out_video)

with gr.Tab('Webcam-Video'):
input_video = gr.Video(source='webcam', format='mp4')
button = gr.Button('Inference', variant='primary')
hm = gr.Checkbox(label='draw-heatmap', info='Whether to draw heatmap')
model_type = gr.Dropdown([
'rtmpose | body', 'rtmo | body', 'rtmpose | face',
'dwpose | wholebody', 'rtmw | wholebody'
],
label='Model | Keypoint Type',
info='Body / Face / Wholebody')

gr.Markdown('## News')
for news in news_list[::-1]:
gr.Markdown(news)

gr.Markdown('## Output')
out_video = gr.Video()

input_type = 'video'
button.click(
partial(predict, input_type=input_type),
[input_video, hm, model_type], out_video)

gr.close_all()
demo.queue()
demo.launch()

0 comments on commit b52b302

Please sign in to comment.