From cc5f24878b365d48262be8506b6d84304cbae9c1 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Wed, 26 Jul 2023 20:16:27 +0800 Subject: [PATCH 01/21] add badcase hook --- mmpose/engine/hooks/badcase_hook.py | 168 ++++++++++++++++++++++++++++ 1 file changed, 168 insertions(+) create mode 100644 mmpose/engine/hooks/badcase_hook.py diff --git a/mmpose/engine/hooks/badcase_hook.py b/mmpose/engine/hooks/badcase_hook.py new file mode 100644 index 0000000000..1a245399a8 --- /dev/null +++ b/mmpose/engine/hooks/badcase_hook.py @@ -0,0 +1,168 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import warnings +from typing import Optional, Sequence + +import mmcv +import mmengine +import mmengine.fileio as fileio +from mmengine.hooks import Hook +from mmengine.runner import Runner +from mmengine.visualization import Visualizer + +from mmpose.registry import HOOKS +from mmpose.structures import PoseDataSample, merge_data_samples + + +@HOOKS.register_module() +class BadCaseAnalyzeHook(Hook): + """Bad Case Analyze Hook. Used to visualize validation and + testing process prediction results. + + In the testing phase: + + 1. If ``show`` is True, it means that only the prediction results are + visualized without storing data, so ``vis_backends`` needs to + be excluded. + 2. If ``out_dir`` is specified, it means that the prediction results + need to be saved to ``out_dir``. In order to avoid vis_backends + also storing data, so ``vis_backends`` needs to be excluded. + 3. ``vis_backends`` takes effect if the user does not specify ``show`` + and `out_dir``. You can set ``vis_backends`` to WandbVisBackend or + TensorboardVisBackend to store the prediction result in Wandb or + Tensorboard. + + Args: + enable (bool): whether to draw prediction results. If it is False, + it means that no drawing will be done. Defaults to False. + interval (int): The interval of visualization. Defaults to 50. + score_thr (float): The threshold to visualize the bboxes + and masks. Defaults to 0.3. + show (bool): Whether to display the drawn image. Default to False. + wait_time (float): The interval of show (s). Defaults to 0. + out_dir (str, optional): directory where painted images + will be saved in testing process. + backend_args (dict, optional): Arguments to instantiate the preifx of + uri corresponding backend. Defaults to None. + """ + + def __init__( + self, + enable: bool = False, + interval: int = 50, + kpt_thr: float = 0.3, + show: bool = False, + wait_time: float = 0., + out_dir: Optional[str] = None, + backend_args: Optional[dict] = None, + ): + self._visualizer: Visualizer = Visualizer.get_current_instance() + self.interval = interval + self.kpt_thr = kpt_thr + self.show = show + if self.show: + # No need to think about vis backends. + self._visualizer._vis_backends = {} + warnings.warn('The show is True, it means that only ' + 'the prediction results are visualized ' + 'without storing data, so vis_backends ' + 'needs to be excluded.') + + self.wait_time = wait_time + self.enable = enable + self.out_dir = out_dir + self._test_index = 0 + self.backend_args = backend_args + + def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[PoseDataSample]) -> None: + """Run after every ``self.interval`` validation iterations. + + Args: + runner (:obj:`Runner`): The runner of the validation process. + batch_idx (int): The index of the current batch in the val loop. + data_batch (dict): Data from dataloader. 
+ outputs (Sequence[:obj:`PoseDataSample`]): Outputs from model. + """ + if self.enable is False: + return + + self._visualizer.set_dataset_meta(runner.val_evaluator.dataset_meta) + + # There is no guarantee that the same batch of images + # is visualized for each evaluation. + total_curr_iter = runner.iter + batch_idx + + # Visualize only the first data + img_path = data_batch['data_samples'][0].get('img_path') + img_bytes = fileio.get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + data_sample = outputs[0] + + # revert the heatmap on the original image + data_sample = merge_data_samples([data_sample]) + + if total_curr_iter % self.interval == 0: + self._visualizer.add_datasample( + os.path.basename(img_path) if self.show else 'val_img', + img, + data_sample=data_sample, + draw_gt=False, + draw_bbox=True, + draw_heatmap=True, + show=self.show, + wait_time=self.wait_time, + kpt_thr=self.kpt_thr, + step=total_curr_iter) + + def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, + outputs: Sequence[PoseDataSample]) -> None: + """Run after every testing iterations. + + Args: + runner (:obj:`Runner`): The runner of the testing process. + batch_idx (int): The index of the current batch in the test loop. + data_batch (dict): Data from dataloader. + outputs (Sequence[:obj:`PoseDataSample`]): Outputs from model. + """ + if self.enable is False: + return + + if self.out_dir is not None: + self.out_dir = os.path.join(runner.work_dir, runner.timestamp, + self.out_dir) + mmengine.mkdir_or_exist(self.out_dir) + + self._visualizer.set_dataset_meta(runner.test_evaluator.dataset_meta) + + for data_sample in outputs: + self._test_index += 1 + + img_path = data_sample.get('img_path') + img_bytes = fileio.get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + data_sample = merge_data_samples([data_sample]) + + out_file = None + if self.out_dir is not None: + out_file_name, postfix = os.path.basename(img_path).rsplit( + '.', 1) + index = len([ + fname for fname in os.listdir(self.out_dir) + if fname.startswith(out_file_name) + ]) + out_file = f'{out_file_name}_{index}.{postfix}' + out_file = os.path.join(self.out_dir, out_file) + + self._visualizer.add_datasample( + os.path.basename(img_path) if self.show else 'test_img', + img, + data_sample=data_sample, + show=self.show, + draw_gt=False, + draw_bbox=True, + draw_heatmap=True, + wait_time=self.wait_time, + kpt_thr=self.kpt_thr, + out_file=out_file, + step=self._test_index) From cc9e29e06dfbfaa7f6d93de2304efd3bb4b162db Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Sat, 29 Jul 2023 16:27:53 +0800 Subject: [PATCH 02/21] add loss based badcase analyze --- ...-hm_mobilenetv2_8xb64-210e_mpii-256x256.py | 7 +- mmpose/engine/hooks/__init__.py | 3 +- mmpose/engine/hooks/badcase_hook.py | 160 +++++++++++------- tools/test.py | 50 ++++++ 4 files changed, 155 insertions(+), 65 deletions(-) diff --git a/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256.py b/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256.py index 41b9d3ba9b..82cee697d9 100644 --- a/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256.py +++ b/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256.py @@ -27,7 +27,12 @@ auto_scale_lr = dict(base_batch_size=512) # hooks -default_hooks = dict(checkpoint=dict(save_best='PCK', 
rule='greater')) +default_hooks = dict(checkpoint=dict(save_best='PCK', rule='greater'), + badcase=dict(type="BadCaseAnalyzeHook", + metric_type="loss", + show=True, + # metric_type="accuracy", + out_dir='badcase')) # codec settings codec = dict( diff --git a/mmpose/engine/hooks/__init__.py b/mmpose/engine/hooks/__init__.py index dadb9c5f91..90ba316a8f 100644 --- a/mmpose/engine/hooks/__init__.py +++ b/mmpose/engine/hooks/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ema_hook import ExpMomentumEMA from .visualization_hook import PoseVisualizationHook +from .badcase_hook import BadCaseAnalyzeHook -__all__ = ['PoseVisualizationHook', 'ExpMomentumEMA'] +__all__ = ['PoseVisualizationHook', 'ExpMomentumEMA', 'BadCaseAnalyzeHook'] diff --git a/mmpose/engine/hooks/badcase_hook.py b/mmpose/engine/hooks/badcase_hook.py index 1a245399a8..7e9bd6a06e 100644 --- a/mmpose/engine/hooks/badcase_hook.py +++ b/mmpose/engine/hooks/badcase_hook.py @@ -1,7 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. import os +import json +import torch import warnings -from typing import Optional, Sequence +import numpy as np +from typing import Optional, Sequence, Dict import mmcv import mmengine @@ -10,7 +13,7 @@ from mmengine.runner import Runner from mmengine.visualization import Visualizer -from mmpose.registry import HOOKS +from mmpose.registry import HOOKS, MODELS, METRICS from mmpose.structures import PoseDataSample, merge_data_samples @@ -49,12 +52,15 @@ class BadCaseAnalyzeHook(Hook): def __init__( self, enable: bool = False, - interval: int = 50, - kpt_thr: float = 0.3, show: bool = False, wait_time: float = 0., + interval: int = 50, + kpt_thr: float = 0.3, out_dir: Optional[str] = None, backend_args: Optional[dict] = None, + metric_type: str = 'loss', + metric: dict = dict(type='KeypointMSELoss'), + badcase_thr: float = 5, ): self._visualizer: Visualizer = Visualizer.get_current_instance() self.interval = interval @@ -74,46 +80,29 @@ def __init__( self._test_index = 0 self.backend_args = backend_args - def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, - outputs: Sequence[PoseDataSample]) -> None: - """Run after every ``self.interval`` validation iterations. + self.metric_type = metric_type + self.metric = MODELS.build(metric) if metric_type == 'loss' else METRICS.build(metric) + self.metric_name = metric.type + self.badcase_thr = badcase_thr + self.results = [] + + def check_badcase(self, preds, gts): + """Check whether the sample is a badcase Args: - runner (:obj:`Runner`): The runner of the validation process. - batch_idx (int): The index of the current batch in the val loop. - data_batch (dict): Data from dataloader. - outputs (Sequence[:obj:`PoseDataSample`]): Outputs from model. + gts (np.ndarray): gts of the sample + preds (np.ndarray): preds of the sample + Return: + is_badcase (bool): whether the sample is a badcase or not + metric_value (float) """ - if self.enable is False: - return - - self._visualizer.set_dataset_meta(runner.val_evaluator.dataset_meta) - - # There is no guarantee that the same batch of images - # is visualized for each evaluation. 
- total_curr_iter = runner.iter + batch_idx - - # Visualize only the first data - img_path = data_batch['data_samples'][0].get('img_path') - img_bytes = fileio.get(img_path, backend_args=self.backend_args) - img = mmcv.imfrombytes(img_bytes, channel_order='rgb') - data_sample = outputs[0] - - # revert the heatmap on the original image - data_sample = merge_data_samples([data_sample]) - - if total_curr_iter % self.interval == 0: - self._visualizer.add_datasample( - os.path.basename(img_path) if self.show else 'val_img', - img, - data_sample=data_sample, - draw_gt=False, - draw_bbox=True, - draw_heatmap=True, - show=self.show, - wait_time=self.wait_time, - kpt_thr=self.kpt_thr, - step=total_curr_iter) + if self.metric_type == 'loss': + with torch.no_grad(): + metric_value = self.metric(torch.tensor(preds), torch.tensor(gts)).item() + is_badcase = metric_value >= self.badcase_thr + else: + is_badcase = metric_value <= self.badcase_thr + return is_badcase, metric_value def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, outputs: Sequence[PoseDataSample]) -> None: @@ -143,26 +132,71 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, img = mmcv.imfrombytes(img_bytes, channel_order='rgb') data_sample = merge_data_samples([data_sample]) - out_file = None - if self.out_dir is not None: - out_file_name, postfix = os.path.basename(img_path).rsplit( + gts = data_sample.gt_instances.keypoints + preds = data_sample.pred_instances.keypoints + is_badcase, metric_value = self.check_badcase(gts, preds) + + if is_badcase: + img_name, postfix = os.path.basename(img_path).rsplit( '.', 1) - index = len([ - fname for fname in os.listdir(self.out_dir) - if fname.startswith(out_file_name) - ]) - out_file = f'{out_file_name}_{index}.{postfix}' - out_file = os.path.join(self.out_dir, out_file) - - self._visualizer.add_datasample( - os.path.basename(img_path) if self.show else 'test_img', - img, - data_sample=data_sample, - show=self.show, - draw_gt=False, - draw_bbox=True, - draw_heatmap=True, - wait_time=self.wait_time, - kpt_thr=self.kpt_thr, - out_file=out_file, - step=self._test_index) + bboxes = data_sample.gt_instances.bboxes.astype(int).tolist() + bbox_info = 'bbox' + str(bboxes) + metric_postfix = self.metric_name + str(round(metric_value, 2)) + + self.results.append({'img': img_name, + 'bbox': bboxes, + self.metric_name: metric_value}) + + badcase_name = f'{img_name}_{bbox_info}_{metric_postfix}' + + out_file = None + if self.out_dir is not None: + out_file = f'{badcase_name}.{postfix}' + out_file = os.path.join(self.out_dir, out_file) + + # draw gt keypoints in blue color + self._visualizer.kpt_color[:, 0:3] = np.array([0, 0, 255]) + img_gt_drawn = self._visualizer.add_datasample( + badcase_name if self.show else 'test_img', + img, + data_sample=data_sample, + show=False, + draw_pred=False, + draw_gt=True, + draw_bbox=False, + draw_heatmap=False, + wait_time=self.wait_time, + kpt_thr=self.kpt_thr, + out_file=None, + step=self._test_index) + # draw pred keypoints in red color + self._visualizer.kpt_color[:, 0:3] = np.array([255, 0, 0]) + self._visualizer.add_datasample( + badcase_name if self.show else 'test_img', + img_gt_drawn, + data_sample=data_sample, + show=self.show, + draw_pred=True, + draw_gt=False, + draw_bbox=True, + draw_heatmap=False, + wait_time=self.wait_time, + kpt_thr=self.kpt_thr, + out_file=out_file, + step=self._test_index) + + def after_test_epoch(self, + runner, + metrics: Optional[Dict[str, float]] = None) -> None: + """All 
subclasses should override this method, if they need any + operations after each test epoch. + + Args: + runner (Runner): The runner of the testing process. + metrics (Dict[str, float], optional): Evaluation results of all + metrics on test dataset. The keys are the names of the + metrics, and the values are corresponding results. + """ + out_file = os.path.join(self.out_dir, 'results.json') + with open(out_file, 'w') as f: + json.dump(self.results, f) diff --git a/tools/test.py b/tools/test.py index 5dc0110260..8a39af7d36 100644 --- a/tools/test.py +++ b/tools/test.py @@ -52,6 +52,29 @@ def parse_args(): default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument( + '--badcase', + action='store_true', + help='whether analyze badcase in test') + # parser.add_argument( + # '--badcase-dir', + # type=str, + # default='badcase, + # help='directory where the badcases visulization and list will be saved') + # parser.add_argument( + # '--badcase-show', + # action='store_true', + # help='whether to display the badcases in a window.') + # parser.add_argument( + # '--badcase-metric', + # type=str, + # default='wrong_num', + # help='the metric to decide badcase.') + # parser.add_argument( + # '--badcase-thr', + # type=float, + # default=5.0, + # help='the min metric value to be a badcase.') args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) @@ -73,6 +96,9 @@ def merge_args(cfg, args): # use config filename as default work_dir if cfg.work_dir is None cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) + + # if args.show and args.badcase_show: + # raise ValueError('Do not support pred and badcase visualization at the same time') # -------------------- visualization -------------------- if args.show or (args.show_dir is not None): @@ -88,6 +114,30 @@ def merge_args(cfg, args): cfg.default_hooks.visualization.out_dir = args.show_dir cfg.default_hooks.visualization.interval = args.interval + # -------------------- badcase analyze -------------------- + if args.badcase: + assert 'badcase' in cfg.default_hooks, \ + 'BadcaseAnalyzeHook is not set in the ' \ + '`default_hooks` field of config. 
Please set ' \ + '`badcase=dict(type="BadcaseAnalyzeHook")`' + + cfg.default_hooks.badcase.enable = True + badcase_show = cfg.default_hooks.badcase.get('show', 'False') + if badcase_show: + cfg.default_hooks.badcase.wait_time = args.wait_time + cfg.default_hooks.badcase.interval = args.interval + + metric_type = cfg.default_hooks.badcase.get('metric_type', 'loss') + if metric_type not in ['loss', 'accuracy']: + raise ValueError("Only support badcase metric type in ['loss', 'accuracy']") + + if metric_type == 'loss': + if not cfg.default_hooks.badcase.get('metric'): + cfg.default_hooks.badcase.metric = cfg.model.head.loss + else: + if not cfg.default_hooks.badcase.get('metric'): + cfg.default_hooks.badcase.metric = cfg.test_evaluator + # -------------------- Dump predictions -------------------- if args.dump is not None: assert args.dump.endswith(('.pkl', '.pickle')), \ From c6218378064cf3c575d0c7a2b3f1fd17d3d730d0 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Sun, 30 Jul 2023 15:04:14 +0800 Subject: [PATCH 03/21] support accurayc based badcase analyze --- mmpose/engine/hooks/badcase_hook.py | 45 ++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/mmpose/engine/hooks/badcase_hook.py b/mmpose/engine/hooks/badcase_hook.py index 7e9bd6a06e..ab368ce0c7 100644 --- a/mmpose/engine/hooks/badcase_hook.py +++ b/mmpose/engine/hooks/badcase_hook.py @@ -38,15 +38,18 @@ class BadCaseAnalyzeHook(Hook): Args: enable (bool): whether to draw prediction results. If it is False, it means that no drawing will be done. Defaults to False. - interval (int): The interval of visualization. Defaults to 50. - score_thr (float): The threshold to visualize the bboxes - and masks. Defaults to 0.3. show (bool): Whether to display the drawn image. Default to False. wait_time (float): The interval of show (s). Defaults to 0. + interval (int): The interval of visualization. Defaults to 50. + kpt_thr (float): The threshold to visualize the keypoints. Defaults to 0.3. out_dir (str, optional): directory where painted images will be saved in testing process. backend_args (dict, optional): Arguments to instantiate the preifx of uri corresponding backend. Defaults to None. + metric_type (str): the mretic type to decide a badcase, loss or accuracy. + metric (dict): The config of metric. + metric_key (str): key of needed metric value in the return dict from class 'metric'. + badcase_thr (float): min loss or max accuracy for a badcase. """ def __init__( @@ -60,6 +63,7 @@ def __init__( backend_args: Optional[dict] = None, metric_type: str = 'loss', metric: dict = dict(type='KeypointMSELoss'), + metric_key: str = 'PCK', badcase_thr: float = 5, ): self._visualizer: Visualizer = Visualizer.get_current_instance() @@ -81,26 +85,41 @@ def __init__( self.backend_args = backend_args self.metric_type = metric_type - self.metric = MODELS.build(metric) if metric_type == 'loss' else METRICS.build(metric) - self.metric_name = metric.type + if metric_type not in ['loss', 'accuracy']: + raise KeyError( + f'The badcase metric type {metric_type} is not supported by ' + f"{self.__class__.__name__}. 
Should be one of 'loss', " + f"'accuracy', but got {metric_type}.") + self.metric = MODELS.build(metric) if metric_type == 'loss'\ + else METRICS.build(metric) + self.metric_name = metric.type if metric_type == 'loss'\ + else metric_key + self.metric_key = metric_key self.badcase_thr = badcase_thr self.results = [] - def check_badcase(self, preds, gts): + def check_badcase(self, data_batch, data_sample): """Check whether the sample is a badcase Args: - gts (np.ndarray): gts of the sample - preds (np.ndarray): preds of the sample + data_batch (Sequence[dict]): A batch of data + from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from + the model. Return: is_badcase (bool): whether the sample is a badcase or not metric_value (float) """ if self.metric_type == 'loss': + gts = data_sample.gt_instances.keypoints + preds = data_sample.pred_instances.keypoints with torch.no_grad(): - metric_value = self.metric(torch.tensor(preds), torch.tensor(gts)).item() + metric_value = self.metric(torch.tensor(preds), + torch.tensor(gts)).item() is_badcase = metric_value >= self.badcase_thr else: + self.metric.process([data_batch], [data_sample.to_dict()]) + metric_value = self.metric.evaluate(1)[self.metric_key] is_badcase = metric_value <= self.badcase_thr return is_badcase, metric_value @@ -132,9 +151,7 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, img = mmcv.imfrombytes(img_bytes, channel_order='rgb') data_sample = merge_data_samples([data_sample]) - gts = data_sample.gt_instances.keypoints - preds = data_sample.pred_instances.keypoints - is_badcase, metric_value = self.check_badcase(gts, preds) + is_badcase, metric_value = self.check_badcase(data_batch, data_sample) if is_badcase: img_name, postfix = os.path.basename(img_path).rsplit( @@ -156,6 +173,7 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, # draw gt keypoints in blue color self._visualizer.kpt_color[:, 0:3] = np.array([0, 0, 255]) + self._visualizer.link_color[:, 0:3] = np.array([0, 0, 255]) img_gt_drawn = self._visualizer.add_datasample( badcase_name if self.show else 'test_img', img, @@ -171,6 +189,7 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, step=self._test_index) # draw pred keypoints in red color self._visualizer.kpt_color[:, 0:3] = np.array([255, 0, 0]) + self._visualizer.link_color[:, 0:3] = np.array([255, 0, 0]) self._visualizer.add_datasample( badcase_name if self.show else 'test_img', img_gt_drawn, @@ -198,5 +217,5 @@ def after_test_epoch(self, metrics, and the values are corresponding results. 
""" out_file = os.path.join(self.out_dir, 'results.json') - with open(out_file, 'w') as f: + with open(out_file, 'w') as f: json.dump(self.results, f) From e1eadcf37b88b2673566e7f8563603938230c5b0 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Sun, 30 Jul 2023 15:30:10 +0800 Subject: [PATCH 04/21] fix configdict bug --- mmpose/engine/hooks/badcase_hook.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mmpose/engine/hooks/badcase_hook.py b/mmpose/engine/hooks/badcase_hook.py index ab368ce0c7..5cab3044c8 100644 --- a/mmpose/engine/hooks/badcase_hook.py +++ b/mmpose/engine/hooks/badcase_hook.py @@ -9,6 +9,7 @@ import mmcv import mmengine import mmengine.fileio as fileio +from mmengine.config import ConfigDict from mmengine.hooks import Hook from mmengine.runner import Runner from mmengine.visualization import Visualizer @@ -47,7 +48,7 @@ class BadCaseAnalyzeHook(Hook): backend_args (dict, optional): Arguments to instantiate the preifx of uri corresponding backend. Defaults to None. metric_type (str): the mretic type to decide a badcase, loss or accuracy. - metric (dict): The config of metric. + metric (ConfigDict): The config of metric. metric_key (str): key of needed metric value in the return dict from class 'metric'. badcase_thr (float): min loss or max accuracy for a badcase. """ @@ -62,7 +63,7 @@ def __init__( out_dir: Optional[str] = None, backend_args: Optional[dict] = None, metric_type: str = 'loss', - metric: dict = dict(type='KeypointMSELoss'), + metric: ConfigDict = ConfigDict(type='KeypointMSELoss'), metric_key: str = 'PCK', badcase_thr: float = 5, ): From 3496be936e036d34de053fec6da33f8f6631666d Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Sun, 30 Jul 2023 15:32:34 +0800 Subject: [PATCH 05/21] revert cfg --- .../mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256.py b/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256.py index 82cee697d9..41b9d3ba9b 100644 --- a/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256.py +++ b/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256.py @@ -27,12 +27,7 @@ auto_scale_lr = dict(base_batch_size=512) # hooks -default_hooks = dict(checkpoint=dict(save_best='PCK', rule='greater'), - badcase=dict(type="BadCaseAnalyzeHook", - metric_type="loss", - show=True, - # metric_type="accuracy", - out_dir='badcase')) +default_hooks = dict(checkpoint=dict(save_best='PCK', rule='greater')) # codec settings codec = dict( From bfd17073742a1ebdfdd1fe4871d412a131ef7858 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Sun, 30 Jul 2023 15:33:05 +0800 Subject: [PATCH 06/21] add badcase analyze sample cfg --- ...lenetv2_8xb64-210e_mpii-256x256_badcase.py | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256_badcase.py diff --git a/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256_badcase.py b/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256_badcase.py new file mode 100644 index 0000000000..69af5571e3 --- /dev/null +++ b/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256_badcase.py @@ -0,0 +1,124 @@ +_base_ = 
['../../../_base_/default_runtime.py'] + +# runtime +train_cfg = dict(max_epochs=210, val_interval=10) + +# optimizer +optim_wrapper = dict(optimizer=dict( + type='Adam', + lr=5e-4, +)) + +# learning policy +param_scheduler = [ + dict( + type='LinearLR', begin=0, end=500, start_factor=0.001, + by_epoch=False), # warm-up + dict( + type='MultiStepLR', + begin=0, + end=210, + milestones=[170, 200], + gamma=0.1, + by_epoch=True) +] + +# automatically scaling LR based on the actual training batch size +auto_scale_lr = dict(base_batch_size=512) + +# hooks +default_hooks = dict(checkpoint=dict(save_best='PCK', rule='greater'), + badcase=dict(type="BadCaseAnalyzeHook", + # metric_type="loss", + metric_type="accuracy", + show=True, + badcase_thr=100, + out_dir='badcase')) + +# codec settings +codec = dict( + type='MSRAHeatmap', input_size=(256, 256), heatmap_size=(64, 64), sigma=2) + +# model settings +model = dict( + type='TopdownPoseEstimator', + data_preprocessor=dict( + type='PoseDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True), + backbone=dict( + type='MobileNetV2', + widen_factor=1., + out_indices=(7, ), + init_cfg=dict(type='Pretrained', checkpoint='mmcls://mobilenet_v2'), + ), + head=dict( + type='HeatmapHead', + in_channels=1280, + out_channels=16, + loss=dict(type='KeypointMSELoss', use_target_weight=True), + decoder=codec), + test_cfg=dict( + flip_test=True, + flip_mode='heatmap', + shift_heatmap=True, + )) + +# base dataset settings +dataset_type = 'MpiiDataset' +data_mode = 'topdown' +data_root = 'data/mpii/' + +# pipelines +train_pipeline = [ + dict(type='LoadImage'), + dict(type='GetBBoxCenterScale'), + dict(type='RandomFlip', direction='horizontal'), + dict(type='RandomBBoxTransform', shift_prob=0), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='GenerateTarget', encoder=codec), + dict(type='PackPoseInputs') +] +val_pipeline = [ + dict(type='LoadImage'), + dict(type='GetBBoxCenterScale'), + dict(type='TopdownAffine', input_size=codec['input_size']), + dict(type='PackPoseInputs') +] + +# data loaders +train_dataloader = dict( + batch_size=64, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='annotations/mpii_train.json', + data_prefix=dict(img='images/'), + pipeline=train_pipeline, + )) +val_dataloader = dict( + batch_size=32, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_mode=data_mode, + ann_file='annotations/mpii_val.json', + headbox_file='data/mpii/annotations/mpii_gt_val.mat', + data_prefix=dict(img='images/'), + test_mode=True, + pipeline=val_pipeline, + )) +test_dataloader = val_dataloader + +# evaluators +val_evaluator = dict(type='MpiiPCKAccuracy') +test_evaluator = val_evaluator From c658f8780b78d4747a9c89c10c83fa543241317b Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Sun, 30 Jul 2023 16:07:36 +0800 Subject: [PATCH 07/21] support draw_line with str value color --- mmpose/visualization/opencv_backend_visualizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmpose/visualization/opencv_backend_visualizer.py b/mmpose/visualization/opencv_backend_visualizer.py index 1c17506640..f125207928 100644 --- a/mmpose/visualization/opencv_backend_visualizer.py +++ 
b/mmpose/visualization/opencv_backend_visualizer.py @@ -358,7 +358,8 @@ def draw_lines(self, **kwargs) elif self.backend == 'opencv': - + if isinstance(colors, str): + colors = mmcv.color_val(colors) self._image = cv2.line( self._image, (x_datas[0], y_datas[0]), (x_datas[1], y_datas[1]), From 5773b69d6c5b36af39c53656936badf8c84db56a Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Sun, 30 Jul 2023 16:45:02 +0800 Subject: [PATCH 08/21] add unit test for badcase hook --- .../test_hooks/test_badcase_hook.py | 96 +++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 tests/test_engine/test_hooks/test_badcase_hook.py diff --git a/tests/test_engine/test_hooks/test_badcase_hook.py b/tests/test_engine/test_hooks/test_badcase_hook.py new file mode 100644 index 0000000000..cbeb186e7a --- /dev/null +++ b/tests/test_engine/test_hooks/test_badcase_hook.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import os.path as osp +import shutil +import time +from unittest import TestCase +from unittest.mock import MagicMock + +import numpy as np +from mmengine.config import ConfigDict +from mmengine.structures import InstanceData + +from mmpose.engine.hooks import BadCaseAnalyzeHook +from mmpose.structures import PoseDataSample +from mmpose.visualization import PoseLocalVisualizer + + +def _rand_poses(num_boxes, kpt_num, h, w): + center = np.random.rand(num_boxes, 2) + offset = np.random.rand(num_boxes, kpt_num, 2) / 2.0 + + pose = center[:, None, :] + offset.clip(0, 1) + pose[:, :, 0] *= w + pose[:, :, 1] *= h + + return pose + + +class TestBadCaseHook(TestCase): + + def setUp(self) -> None: + kpt_num = 16 + PoseLocalVisualizer.get_instance('test_badcase_hook') + + data_sample = PoseDataSample() + data_sample.set_metainfo({ + 'img_path': + osp.join( + osp.dirname(__file__), '../../data/coco/000000000785.jpg') + }) + self.data_batch = {'data_samples': [data_sample] * 2} + + pred_det_data_sample = data_sample.clone() + pred_instances = InstanceData() + pred_instances.keypoints = _rand_poses(1, kpt_num, 10, 12) + pred_det_data_sample.pred_instances = pred_instances + + gt_instances = InstanceData() + gt_instances.keypoints = _rand_poses(1, kpt_num, 10, 12) + gt_instances.keypoints_visible = np.ones((1, kpt_num)) + gt_instances.head_size = np.random.rand(1, 1) + gt_instances.bboxes = np.random.rand(1, 4) + pred_det_data_sample.gt_instances = gt_instances + self.outputs = [pred_det_data_sample] * 2 + + def test_after_test_iter(self): + runner = MagicMock() + runner.iter = 1 + + # test + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + out_dir = timestamp + '1' + runner.work_dir = timestamp + runner.timestamp = '1' + hook = BadCaseAnalyzeHook(enable=False, out_dir=out_dir) + hook.after_test_iter(runner, 1, self.data_batch, self.outputs) + self.assertTrue(not osp.exists(f'{timestamp}/1/{out_dir}')) + + hook = BadCaseAnalyzeHook(enable=True, metric_type="loss", + metric=ConfigDict(type='KeypointMSELoss'), + badcase_thr=-1, # is_badcase = True + out_dir=out_dir) + hook.after_test_iter(runner, 1, self.data_batch, self.outputs) + self.assertEqual(hook._test_index, 2) + self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}')) + # same image and preds/gts, so onlu one file + self.assertTrue(len(os.listdir(f'{timestamp}/1/{out_dir}')) == 1) + + hook.after_test_epoch(runner) + self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}/results.json')) + shutil.rmtree(f'{timestamp}') + + hook = BadCaseAnalyzeHook(enable=True, metric_type="accuracy", + 
metric=ConfigDict(type='MpiiPCKAccuracy'), + badcase_thr=-1, # is_badcase = False + out_dir=out_dir) + hook.after_test_iter(runner, 1, self.data_batch, self.outputs) + self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}')) + self.assertTrue(len(os.listdir(f'{timestamp}/1/{out_dir}')) == 0) + shutil.rmtree(f'{timestamp}') + + +if __name__ == "__main__": + test = TestBadCaseHook() + test.setUp() + test.test_after_test_iter() From 8e297294c742089509b6954121a6a56ded3eac25 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Sun, 30 Jul 2023 16:45:31 +0800 Subject: [PATCH 09/21] use str based color --- mmpose/engine/hooks/badcase_hook.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmpose/engine/hooks/badcase_hook.py b/mmpose/engine/hooks/badcase_hook.py index 5cab3044c8..080e70136f 100644 --- a/mmpose/engine/hooks/badcase_hook.py +++ b/mmpose/engine/hooks/badcase_hook.py @@ -173,8 +173,8 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, out_file = os.path.join(self.out_dir, out_file) # draw gt keypoints in blue color - self._visualizer.kpt_color[:, 0:3] = np.array([0, 0, 255]) - self._visualizer.link_color[:, 0:3] = np.array([0, 0, 255]) + self._visualizer.kpt_color = 'blue' + self._visualizer.link_color = 'blue' img_gt_drawn = self._visualizer.add_datasample( badcase_name if self.show else 'test_img', img, @@ -189,8 +189,8 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, out_file=None, step=self._test_index) # draw pred keypoints in red color - self._visualizer.kpt_color[:, 0:3] = np.array([255, 0, 0]) - self._visualizer.link_color[:, 0:3] = np.array([255, 0, 0]) + self._visualizer.kpt_color = 'red' + self._visualizer.link_color = 'red' self._visualizer.add_datasample( badcase_name if self.show else 'test_img', img_gt_drawn, From 4cfc7ab8eabd12a4a00fa4f537d3df606b4fa3d0 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Sun, 30 Jul 2023 16:46:14 +0800 Subject: [PATCH 10/21] rm useless codes and add warnings --- tools/test.py | 29 ++++++----------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/tools/test.py b/tools/test.py index 8a39af7d36..80e7e409ca 100644 --- a/tools/test.py +++ b/tools/test.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import argparse import os +import warnings import os.path as osp import mmengine @@ -56,25 +57,6 @@ def parse_args(): '--badcase', action='store_true', help='whether analyze badcase in test') - # parser.add_argument( - # '--badcase-dir', - # type=str, - # default='badcase, - # help='directory where the badcases visulization and list will be saved') - # parser.add_argument( - # '--badcase-show', - # action='store_true', - # help='whether to display the badcases in a window.') - # parser.add_argument( - # '--badcase-metric', - # type=str, - # default='wrong_num', - # help='the metric to decide badcase.') - # parser.add_argument( - # '--badcase-thr', - # type=float, - # default=5.0, - # help='the min metric value to be a badcase.') args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: os.environ['LOCAL_RANK'] = str(args.local_rank) @@ -96,9 +78,6 @@ def merge_args(cfg, args): # use config filename as default work_dir if cfg.work_dir is None cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) - - # if args.show and args.badcase_show: - # raise ValueError('Do not support pred and badcase visualization at the same time') # -------------------- visualization -------------------- if args.show or (args.show_dir is not None): @@ -125,11 +104,15 @@ def merge_args(cfg, args): badcase_show = cfg.default_hooks.badcase.get('show', 'False') if badcase_show: cfg.default_hooks.badcase.wait_time = args.wait_time + if args.show: + warnings.warn("Enabling both pred and badcase" + "visualiztion can be confusing") cfg.default_hooks.badcase.interval = args.interval metric_type = cfg.default_hooks.badcase.get('metric_type', 'loss') if metric_type not in ['loss', 'accuracy']: - raise ValueError("Only support badcase metric type in ['loss', 'accuracy']") + raise ValueError("Only support badcase metric type" + "in ['loss', 'accuracy']") if metric_type == 'loss': if not cfg.default_hooks.badcase.get('metric'): From 54eecc3c5d52599ac4cddd0c9c8a06f725c3a41b Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Tue, 1 Aug 2023 23:02:45 +0800 Subject: [PATCH 11/21] move badcase hook config to default_runtime.py --- configs/_base_/default_runtime.py | 5 + ...lenetv2_8xb64-210e_mpii-256x256_badcase.py | 124 ------------------ 2 files changed, 5 insertions(+), 124 deletions(-) delete mode 100644 configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256_badcase.py diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index 561d574fa7..c2e9d848f3 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -8,6 +8,11 @@ checkpoint=dict(type='CheckpointHook', interval=10), sampler_seed=dict(type='DistSamplerSeedHook'), visualization=dict(type='PoseVisualizationHook', enable=False), + badcase=dict(type="BadCaseAnalyzeHook", + metric_type="loss", + badcase_thr=100, + show=True, + out_dir='badcase') ) # custom hooks diff --git a/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256_badcase.py b/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256_badcase.py deleted file mode 100644 index 69af5571e3..0000000000 --- a/configs/body_2d_keypoint/topdown_heatmap/mpii/td-hm_mobilenetv2_8xb64-210e_mpii-256x256_badcase.py +++ /dev/null @@ -1,124 +0,0 @@ -_base_ = ['../../../_base_/default_runtime.py'] - -# runtime -train_cfg = dict(max_epochs=210, val_interval=10) - -# optimizer -optim_wrapper = dict(optimizer=dict( - type='Adam', - 
lr=5e-4, -)) - -# learning policy -param_scheduler = [ - dict( - type='LinearLR', begin=0, end=500, start_factor=0.001, - by_epoch=False), # warm-up - dict( - type='MultiStepLR', - begin=0, - end=210, - milestones=[170, 200], - gamma=0.1, - by_epoch=True) -] - -# automatically scaling LR based on the actual training batch size -auto_scale_lr = dict(base_batch_size=512) - -# hooks -default_hooks = dict(checkpoint=dict(save_best='PCK', rule='greater'), - badcase=dict(type="BadCaseAnalyzeHook", - # metric_type="loss", - metric_type="accuracy", - show=True, - badcase_thr=100, - out_dir='badcase')) - -# codec settings -codec = dict( - type='MSRAHeatmap', input_size=(256, 256), heatmap_size=(64, 64), sigma=2) - -# model settings -model = dict( - type='TopdownPoseEstimator', - data_preprocessor=dict( - type='PoseDataPreprocessor', - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - bgr_to_rgb=True), - backbone=dict( - type='MobileNetV2', - widen_factor=1., - out_indices=(7, ), - init_cfg=dict(type='Pretrained', checkpoint='mmcls://mobilenet_v2'), - ), - head=dict( - type='HeatmapHead', - in_channels=1280, - out_channels=16, - loss=dict(type='KeypointMSELoss', use_target_weight=True), - decoder=codec), - test_cfg=dict( - flip_test=True, - flip_mode='heatmap', - shift_heatmap=True, - )) - -# base dataset settings -dataset_type = 'MpiiDataset' -data_mode = 'topdown' -data_root = 'data/mpii/' - -# pipelines -train_pipeline = [ - dict(type='LoadImage'), - dict(type='GetBBoxCenterScale'), - dict(type='RandomFlip', direction='horizontal'), - dict(type='RandomBBoxTransform', shift_prob=0), - dict(type='TopdownAffine', input_size=codec['input_size']), - dict(type='GenerateTarget', encoder=codec), - dict(type='PackPoseInputs') -] -val_pipeline = [ - dict(type='LoadImage'), - dict(type='GetBBoxCenterScale'), - dict(type='TopdownAffine', input_size=codec['input_size']), - dict(type='PackPoseInputs') -] - -# data loaders -train_dataloader = dict( - batch_size=64, - num_workers=2, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - dataset=dict( - type=dataset_type, - data_root=data_root, - data_mode=data_mode, - ann_file='annotations/mpii_train.json', - data_prefix=dict(img='images/'), - pipeline=train_pipeline, - )) -val_dataloader = dict( - batch_size=32, - num_workers=2, - persistent_workers=True, - drop_last=False, - sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), - dataset=dict( - type=dataset_type, - data_root=data_root, - data_mode=data_mode, - ann_file='annotations/mpii_val.json', - headbox_file='data/mpii/annotations/mpii_gt_val.mat', - data_prefix=dict(img='images/'), - test_mode=True, - pipeline=val_pipeline, - )) -test_dataloader = val_dataloader - -# evaluators -val_evaluator = dict(type='MpiiPCKAccuracy') -test_evaluator = val_evaluator From 271d10557993dbfe2b3a3175c954d69b56590044 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Thu, 3 Aug 2023 15:36:46 +0800 Subject: [PATCH 12/21] rename badcase hook and fix linting --- configs/_base_/default_runtime.py | 12 ++++---- mmpose/engine/hooks/badcase_hook.py | 48 ++++++++++++++++------------- 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index c2e9d848f3..f278c32fe2 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -8,12 +8,12 @@ checkpoint=dict(type='CheckpointHook', interval=10), sampler_seed=dict(type='DistSamplerSeedHook'), 
visualization=dict(type='PoseVisualizationHook', enable=False), - badcase=dict(type="BadCaseAnalyzeHook", - metric_type="loss", - badcase_thr=100, - show=True, - out_dir='badcase') -) + badcase=dict( + type='BadCaseAnalysisHook', + metric_type='loss', + badcase_thr=100, + show=True, + out_dir='badcase')) # custom hooks custom_hooks = [ diff --git a/mmpose/engine/hooks/badcase_hook.py b/mmpose/engine/hooks/badcase_hook.py index 080e70136f..3378eb8985 100644 --- a/mmpose/engine/hooks/badcase_hook.py +++ b/mmpose/engine/hooks/badcase_hook.py @@ -1,27 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os import json -import torch +import os import warnings -import numpy as np -from typing import Optional, Sequence, Dict +from typing import Dict, Optional, Sequence import mmcv import mmengine import mmengine.fileio as fileio +import torch from mmengine.config import ConfigDict from mmengine.hooks import Hook from mmengine.runner import Runner from mmengine.visualization import Visualizer -from mmpose.registry import HOOKS, MODELS, METRICS +from mmpose.registry import HOOKS, METRICS, MODELS from mmpose.structures import PoseDataSample, merge_data_samples @HOOKS.register_module() -class BadCaseAnalyzeHook(Hook): - """Bad Case Analyze Hook. Used to visualize validation and - testing process prediction results. +class BadCaseAnalysisHook(Hook): + """Bad Case Analyze Hook. Used to visualize validation and testing process + prediction results. In the testing phase: @@ -42,14 +41,17 @@ class BadCaseAnalyzeHook(Hook): show (bool): Whether to display the drawn image. Default to False. wait_time (float): The interval of show (s). Defaults to 0. interval (int): The interval of visualization. Defaults to 50. - kpt_thr (float): The threshold to visualize the keypoints. Defaults to 0.3. + kpt_thr (float): The threshold to visualize the keypoints. + Defaults to 0.3. out_dir (str, optional): directory where painted images will be saved in testing process. backend_args (dict, optional): Arguments to instantiate the preifx of uri corresponding backend. Defaults to None. - metric_type (str): the mretic type to decide a badcase, loss or accuracy. + metric_type (str): the mretic type to decide a badcase, + loss or accuracy. metric (ConfigDict): The config of metric. - metric_key (str): key of needed metric value in the return dict from class 'metric'. + metric_key (str): key of needed metric value in the return dict + from class 'metric'. badcase_thr (float): min loss or max accuracy for a badcase. """ @@ -100,7 +102,7 @@ def __init__( self.results = [] def check_badcase(self, data_batch, data_sample): - """Check whether the sample is a badcase + """Check whether the sample is a badcase. Args: data_batch (Sequence[dict]): A batch of data @@ -108,15 +110,15 @@ def check_badcase(self, data_batch, data_sample): data_samples (Sequence[dict]): A batch of outputs from the model. 
Return: - is_badcase (bool): whether the sample is a badcase or not + is_badcase (bool): whether the sample is a badcase or not metric_value (float) """ if self.metric_type == 'loss': gts = data_sample.gt_instances.keypoints preds = data_sample.pred_instances.keypoints with torch.no_grad(): - metric_value = self.metric(torch.tensor(preds), - torch.tensor(gts)).item() + metric_value = self.metric( + torch.tensor(preds), torch.tensor(gts)).item() is_badcase = metric_value >= self.badcase_thr else: self.metric.process([data_batch], [data_sample.to_dict()]) @@ -152,19 +154,21 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, img = mmcv.imfrombytes(img_bytes, channel_order='rgb') data_sample = merge_data_samples([data_sample]) - is_badcase, metric_value = self.check_badcase(data_batch, data_sample) + is_badcase, metric_value = self.check_badcase( + data_batch, data_sample) if is_badcase: - img_name, postfix = os.path.basename(img_path).rsplit( - '.', 1) + img_name, postfix = os.path.basename(img_path).rsplit('.', 1) bboxes = data_sample.gt_instances.bboxes.astype(int).tolist() bbox_info = 'bbox' + str(bboxes) metric_postfix = self.metric_name + str(round(metric_value, 2)) - self.results.append({'img': img_name, - 'bbox': bboxes, - self.metric_name: metric_value}) - + self.results.append({ + 'img': img_name, + 'bbox': bboxes, + self.metric_name: metric_value + }) + badcase_name = f'{img_name}_{bbox_info}_{metric_postfix}' out_file = None From ac462c7dd072e307258fc8dcb1d6a15eae775222 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Thu, 3 Aug 2023 15:51:02 +0800 Subject: [PATCH 13/21] set and sort default cfg of badcase --- configs/_base_/default_runtime.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index f278c32fe2..6f27c0345a 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -10,10 +10,10 @@ visualization=dict(type='PoseVisualizationHook', enable=False), badcase=dict( type='BadCaseAnalysisHook', + enable=False, + out_dir='badcase', metric_type='loss', - badcase_thr=100, - show=True, - out_dir='badcase')) + badcase_thr=5)) # custom hooks custom_hooks = [ From 7595b4f97f8bfe57cfdfe25bb493020846c47a5d Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Thu, 3 Aug 2023 16:17:11 +0800 Subject: [PATCH 14/21] update badcase or pred show logic --- tools/test.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tools/test.py b/tools/test.py index 80e7e409ca..f161c036d1 100644 --- a/tools/test.py +++ b/tools/test.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse import os -import warnings import os.path as osp import mmengine @@ -80,14 +79,15 @@ def merge_args(cfg, args): osp.splitext(osp.basename(args.config))[0]) # -------------------- visualization -------------------- - if args.show or (args.show_dir is not None): + if (args.show and not args.badcase) or (args.show_dir is not None): assert 'visualization' in cfg.default_hooks, \ 'PoseVisualizationHook is not set in the ' \ '`default_hooks` field of config. 
Please set ' \ '`visualization=dict(type="PoseVisualizationHook")`' cfg.default_hooks.visualization.enable = True - cfg.default_hooks.visualization.show = args.show + cfg.default_hooks.visualization.show = False \ + if args.badcase else args.show if args.show: cfg.default_hooks.visualization.wait_time = args.wait_time cfg.default_hooks.visualization.out_dir = args.show_dir @@ -99,21 +99,18 @@ def merge_args(cfg, args): 'BadcaseAnalyzeHook is not set in the ' \ '`default_hooks` field of config. Please set ' \ '`badcase=dict(type="BadcaseAnalyzeHook")`' - + cfg.default_hooks.badcase.enable = True - badcase_show = cfg.default_hooks.badcase.get('show', 'False') - if badcase_show: + cfg.default_hooks.badcase.show = args.show + if args.show: cfg.default_hooks.badcase.wait_time = args.wait_time - if args.show: - warnings.warn("Enabling both pred and badcase" - "visualiztion can be confusing") cfg.default_hooks.badcase.interval = args.interval metric_type = cfg.default_hooks.badcase.get('metric_type', 'loss') if metric_type not in ['loss', 'accuracy']: - raise ValueError("Only support badcase metric type" + raise ValueError('Only support badcase metric type' "in ['loss', 'accuracy']") - + if metric_type == 'loss': if not cfg.default_hooks.badcase.get('metric'): cfg.default_hooks.badcase.metric = cfg.model.head.loss From 2d68e40d7f0cbafc8473c7a7cd8f05c4ef385bb7 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Thu, 3 Aug 2023 16:38:52 +0800 Subject: [PATCH 15/21] fix linting in test_badcase_hook.py --- .../test_hooks/test_badcase_hook.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/test_engine/test_hooks/test_badcase_hook.py b/tests/test_engine/test_hooks/test_badcase_hook.py index cbeb186e7a..564ffa9f61 100644 --- a/tests/test_engine/test_hooks/test_badcase_hook.py +++ b/tests/test_engine/test_hooks/test_badcase_hook.py @@ -44,7 +44,7 @@ def setUp(self) -> None: pred_instances = InstanceData() pred_instances.keypoints = _rand_poses(1, kpt_num, 10, 12) pred_det_data_sample.pred_instances = pred_instances - + gt_instances = InstanceData() gt_instances.keypoints = _rand_poses(1, kpt_num, 10, 12) gt_instances.keypoints_visible = np.ones((1, kpt_num)) @@ -66,31 +66,37 @@ def test_after_test_iter(self): hook.after_test_iter(runner, 1, self.data_batch, self.outputs) self.assertTrue(not osp.exists(f'{timestamp}/1/{out_dir}')) - hook = BadCaseAnalyzeHook(enable=True, metric_type="loss", - metric=ConfigDict(type='KeypointMSELoss'), - badcase_thr=-1, # is_badcase = True - out_dir=out_dir) + hook = BadCaseAnalyzeHook( + enable=True, + out_dir=out_dir, + metric_type='loss', + metric=ConfigDict(type='KeypointMSELoss'), + badcase_thr=-1, # is_badcase = True + ) hook.after_test_iter(runner, 1, self.data_batch, self.outputs) self.assertEqual(hook._test_index, 2) self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}')) # same image and preds/gts, so onlu one file self.assertTrue(len(os.listdir(f'{timestamp}/1/{out_dir}')) == 1) - + hook.after_test_epoch(runner) self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}/results.json')) shutil.rmtree(f'{timestamp}') - hook = BadCaseAnalyzeHook(enable=True, metric_type="accuracy", - metric=ConfigDict(type='MpiiPCKAccuracy'), - badcase_thr=-1, # is_badcase = False - out_dir=out_dir) + hook = BadCaseAnalyzeHook( + enable=True, + out_dir=out_dir, + metric_type='accuracy', + metric=ConfigDict(type='MpiiPCKAccuracy'), + badcase_thr=-1, # is_badcase = False + ) hook.after_test_iter(runner, 1, self.data_batch, 
self.outputs) self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}')) self.assertTrue(len(os.listdir(f'{timestamp}/1/{out_dir}')) == 0) shutil.rmtree(f'{timestamp}') -if __name__ == "__main__": +if __name__ == '__main__': test = TestBadCaseHook() test.setUp() test.test_after_test_iter() From 371a01fa6bca8ce4741c19014e1cf205e81073a5 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Thu, 3 Aug 2023 16:46:36 +0800 Subject: [PATCH 16/21] fix rename bug --- mmpose/engine/hooks/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmpose/engine/hooks/__init__.py b/mmpose/engine/hooks/__init__.py index 90ba316a8f..abfe762881 100644 --- a/mmpose/engine/hooks/__init__.py +++ b/mmpose/engine/hooks/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .badcase_hook import BadCaseAnalysisHook from .ema_hook import ExpMomentumEMA from .visualization_hook import PoseVisualizationHook -from .badcase_hook import BadCaseAnalyzeHook -__all__ = ['PoseVisualizationHook', 'ExpMomentumEMA', 'BadCaseAnalyzeHook'] +__all__ = ['PoseVisualizationHook', 'ExpMomentumEMA', 'BadCaseAnalysisHook'] From 34aca1feb6f795672db100b24aa6fada262c95aa Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Thu, 3 Aug 2023 16:51:10 +0800 Subject: [PATCH 17/21] fix rename bug --- tests/test_engine/test_hooks/test_badcase_hook.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_engine/test_hooks/test_badcase_hook.py b/tests/test_engine/test_hooks/test_badcase_hook.py index 564ffa9f61..4a84506fa8 100644 --- a/tests/test_engine/test_hooks/test_badcase_hook.py +++ b/tests/test_engine/test_hooks/test_badcase_hook.py @@ -10,7 +10,7 @@ from mmengine.config import ConfigDict from mmengine.structures import InstanceData -from mmpose.engine.hooks import BadCaseAnalyzeHook +from mmpose.engine.hooks import BadCaseAnalysisHook from mmpose.structures import PoseDataSample from mmpose.visualization import PoseLocalVisualizer @@ -62,11 +62,11 @@ def test_after_test_iter(self): out_dir = timestamp + '1' runner.work_dir = timestamp runner.timestamp = '1' - hook = BadCaseAnalyzeHook(enable=False, out_dir=out_dir) + hook = BadCaseAnalysisHook(enable=False, out_dir=out_dir) hook.after_test_iter(runner, 1, self.data_batch, self.outputs) self.assertTrue(not osp.exists(f'{timestamp}/1/{out_dir}')) - hook = BadCaseAnalyzeHook( + hook = BadCaseAnalysisHook( enable=True, out_dir=out_dir, metric_type='loss', @@ -83,7 +83,7 @@ def test_after_test_iter(self): self.assertTrue(osp.exists(f'{timestamp}/1/{out_dir}/results.json')) shutil.rmtree(f'{timestamp}') - hook = BadCaseAnalyzeHook( + hook = BadCaseAnalysisHook( enable=True, out_dir=out_dir, metric_type='accuracy', From af3f475084e6cdcbacb0e1a0e55cfea592061621 Mon Sep 17 00:00:00 2001 From: Indigo6 <40358785+Indigo6@users.noreply.github.com> Date: Fri, 4 Aug 2023 09:01:36 +0800 Subject: [PATCH 18/21] Update mmpose/engine/hooks/badcase_hook.py Co-authored-by: Tau --- mmpose/engine/hooks/badcase_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmpose/engine/hooks/badcase_hook.py b/mmpose/engine/hooks/badcase_hook.py index 3378eb8985..2dcc38ccdb 100644 --- a/mmpose/engine/hooks/badcase_hook.py +++ b/mmpose/engine/hooks/badcase_hook.py @@ -118,7 +118,7 @@ def check_badcase(self, data_batch, data_sample): preds = data_sample.pred_instances.keypoints with torch.no_grad(): metric_value = self.metric( - torch.tensor(preds), torch.tensor(gts)).item() + torch.from_numpy(preds), 
torch.from_numpy(gts), torch.from_numpy(weights)).item() is_badcase = metric_value >= self.badcase_thr else: self.metric.process([data_batch], [data_sample.to_dict()]) From 8607f123c6c2b3e594f583555fcaf17fe6449de0 Mon Sep 17 00:00:00 2001 From: Indigo6 <40358785+Indigo6@users.noreply.github.com> Date: Fri, 4 Aug 2023 09:01:51 +0800 Subject: [PATCH 19/21] Update mmpose/engine/hooks/badcase_hook.py Co-authored-by: Tau --- mmpose/engine/hooks/badcase_hook.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmpose/engine/hooks/badcase_hook.py b/mmpose/engine/hooks/badcase_hook.py index 2dcc38ccdb..691ff31b2b 100644 --- a/mmpose/engine/hooks/badcase_hook.py +++ b/mmpose/engine/hooks/badcase_hook.py @@ -116,6 +116,7 @@ def check_badcase(self, data_batch, data_sample): if self.metric_type == 'loss': gts = data_sample.gt_instances.keypoints preds = data_sample.pred_instances.keypoints + weights = data_sample.gt_instances.keypoints_visible with torch.no_grad(): metric_value = self.metric( torch.from_numpy(preds), torch.from_numpy(gts), torch.from_numpy(weights)).item() From 6149d5aa71a65f3902f5c93230756058c6c39980 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Fri, 4 Aug 2023 09:17:18 +0800 Subject: [PATCH 20/21] fix linting --- mmpose/engine/hooks/badcase_hook.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmpose/engine/hooks/badcase_hook.py b/mmpose/engine/hooks/badcase_hook.py index 691ff31b2b..92673806c9 100644 --- a/mmpose/engine/hooks/badcase_hook.py +++ b/mmpose/engine/hooks/badcase_hook.py @@ -119,7 +119,8 @@ def check_badcase(self, data_batch, data_sample): weights = data_sample.gt_instances.keypoints_visible with torch.no_grad(): metric_value = self.metric( - torch.from_numpy(preds), torch.from_numpy(gts), torch.from_numpy(weights)).item() + torch.from_numpy(preds), torch.from_numpy(gts), + torch.from_numpy(weights)).item() is_badcase = metric_value >= self.badcase_thr else: self.metric.process([data_batch], [data_sample.to_dict()]) From 725c1a181c2337250c3e389619c085caec4a0334 Mon Sep 17 00:00:00 2001 From: Indigo6 <1401322857@qq.com> Date: Fri, 4 Aug 2023 09:25:37 +0800 Subject: [PATCH 21/21] bgr2rgb after mmcv.color_val --- mmpose/visualization/opencv_backend_visualizer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mmpose/visualization/opencv_backend_visualizer.py b/mmpose/visualization/opencv_backend_visualizer.py index f125207928..9604d07fea 100644 --- a/mmpose/visualization/opencv_backend_visualizer.py +++ b/mmpose/visualization/opencv_backend_visualizer.py @@ -129,7 +129,7 @@ def draw_circles(self, **kwargs) elif self.backend == 'opencv': if isinstance(face_colors, str): - face_colors = mmcv.color_val(face_colors) + face_colors = mmcv.color_val(face_colors)[::-1] if alpha == 1.0: self._image = cv2.circle(self._image, @@ -247,7 +247,7 @@ def draw_texts( if bboxes is not None: bbox_color = bboxes[0]['facecolor'] if isinstance(bbox_color, str): - bbox_color = mmcv.color_val(bbox_color) + bbox_color = mmcv.color_val(bbox_color)[::-1] y = y - text_baseline // 2 self._image = cv2.rectangle( @@ -359,7 +359,7 @@ def draw_lines(self, elif self.backend == 'opencv': if isinstance(colors, str): - colors = mmcv.color_val(colors) + colors = mmcv.color_val(colors)[::-1] self._image = cv2.line( self._image, (x_datas[0], y_datas[0]), (x_datas[1], y_datas[1]),
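
Usage sketch (not part of the patch series above; a rough illustration assuming the series lands as-is — the config name, checkpoint path and threshold values below are placeholders). The patches register `BadCaseAnalysisHook`, add a disabled `badcase` entry to `configs/_base_/default_runtime.py`, and add a `--badcase` switch to `tools/test.py` that enables the hook and resolves its metric from `cfg.model.head.loss` (for `metric_type='loss'`) or `cfg.test_evaluator` (for `metric_type='accuracy'`). A user config inheriting the default runtime might override the entry like this:

# Hypothetical override in a config that inherits _base_/default_runtime.py.
default_hooks = dict(
    badcase=dict(
        type='BadCaseAnalysisHook',
        enable=False,          # flipped to True by `tools/test.py ... --badcase`
        out_dir='badcase',     # images and results.json go to work_dir/<timestamp>/badcase
        metric_type='loss',    # 'loss' -> metric built from cfg.model.head.loss
                               # 'accuracy' -> metric built from cfg.test_evaluator
        metric_key='PCK',      # key read from the evaluator output when metric_type='accuracy'
        badcase_thr=5,         # bad case if loss >= thr (or accuracy <= thr)
    ))

# Then run the test script (config/checkpoint paths are placeholders):
#   python tools/test.py configs/some_config.py work_dirs/some_ckpt.pth --badcase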