diff --git a/configs/textrecog/_base_/default_runtime.py b/configs/textrecog/_base_/default_runtime.py index 6564fcb84..9c81c2fcf 100644 --- a/configs/textrecog/_base_/default_runtime.py +++ b/configs/textrecog/_base_/default_runtime.py @@ -46,3 +46,5 @@ type='TextRecogLocalVisualizer', name='visualizer', vis_backends=vis_backends) + +tta_model = dict(type='EncoderDecoderRecognizerTTAModel') diff --git a/configs/textrecog/abinet/_base_abinet-vision.py b/configs/textrecog/abinet/_base_abinet-vision.py index ef9a482f3..66954ff85 100644 --- a/configs/textrecog/abinet/_base_abinet-vision.py +++ b/configs/textrecog/abinet/_base_abinet-vision.py @@ -116,3 +116,50 @@ type='PackTextRecogInputs', meta_keys=('img_path', 'ori_shape', 'img_shape', 'valid_ratio')) ] + +tta_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='TestTimeAug', + transforms=[ + [ + dict( + type='ConditionApply', + true_transforms=[ + dict( + type='ImgAugWrapper', + args=[dict(cls='Rot90', k=0, keep_size=False)]) + ], + condition="results['img_shape'][1]>> tta_model = dict( + >>> type='EncoderDecoderRecognizerTTAModel') + >>> + >>> tta_pipeline = [ + >>> dict( + >>> type='LoadImageFromFile', + >>> color_type='grayscale', + >>> file_client_args=file_client_args), + >>> dict( + >>> type='TestTimeAug', + >>> transforms=[ + >>> [ + >>> dict( + >>> type='ConditionApply', + >>> true_transforms=[ + >>> dict( + >>> type='ImgAugWrapper', + >>> args=[dict(cls='Rot90', k=0, keep_size=False)]) # noqa: E501 + >>> ], + >>> condition="results['img_shape'][1]>> ), + >>> dict( + >>> type='ConditionApply', + >>> true_transforms=[ + >>> dict( + >>> type='ImgAugWrapper', + >>> args=[dict(cls='Rot90', k=1, keep_size=False)]) # noqa: E501 + >>> ], + >>> condition="results['img_shape'][1]>> ), + >>> dict( + >>> type='ConditionApply', + >>> true_transforms=[ + >>> dict( + >>> type='ImgAugWrapper', + >>> args=[dict(cls='Rot90', k=3, keep_size=False)]) + >>> ], + >>> condition="results['img_shape'][1]>> ), + >>> ], + >>> [ + >>> dict( + >>> type='RescaleToHeight', + >>> height=32, + >>> min_width=32, + >>> max_width=None, + >>> width_divisor=16) + >>> ], + >>> # add loading annotation after ``Resize`` because ground truth + >>> # does not need to do resize data transform + >>> [dict(type='LoadOCRAnnotations', with_text=True)], + >>> [ + >>> dict( + >>> type='PackTextRecogInputs', + >>> meta_keys=('img_path', 'ori_shape', 'img_shape', + >>> 'valid_ratio')) + >>> ] + >>> ]) + >>> ] + """ + + def merge_preds(self, + data_samples_list: List[RecSampleList]) -> RecSampleList: + """Merge predictions of enhanced data to one prediction. + + Args: + data_samples_list (List[RecSampleList]): List of predictions of + all enhanced data. The shape of data_samples_list is (B, M), + where B is the batch size and M is the number of augmented + data. + + Returns: + RecSampleList: Merged prediction. + """ + predictions = list() + for data_samples in data_samples_list: + scores = [ + data_sample.pred_text.score for data_sample in data_samples + ] + average_scores = np.array( + [sum(score) / max(1, len(score)) for score in scores]) + max_idx = np.argmax(average_scores) + predictions.append(data_samples[max_idx]) + return predictions diff --git a/tests/test_models/test_textrecog/test_recognizers/test_encoder_decoder_recognizer_tta.py b/tests/test_models/test_textrecog/test_recognizers/test_encoder_decoder_recognizer_tta.py new file mode 100644 index 000000000..2c2da3f86 --- /dev/null +++ b/tests/test_models/test_textrecog/test_recognizers/test_encoder_decoder_recognizer_tta.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch +import torch.nn as nn +from mmengine.structures import LabelData + +from mmocr.models.textrecog.recognizers import EncoderDecoderRecognizerTTAModel +from mmocr.structures import TextRecogDataSample + + +class DummyModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + def test_step(self, x): + return self.forward(x) + + +class TestEncoderDecoderRecognizerTTAModel(TestCase): + + def test_merge_preds(self): + + data_sample1 = TextRecogDataSample( + pred_text=LabelData( + score=torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]), text='abcde')) + data_sample2 = TextRecogDataSample( + pred_text=LabelData( + score=torch.tensor([0.2, 0.3, 0.4, 0.5, 0.6]), text='bcdef')) + data_sample3 = TextRecogDataSample( + pred_text=LabelData( + score=torch.tensor([0.3, 0.4, 0.5, 0.6, 0.7]), text='cdefg')) + aug_data_samples = [data_sample1, data_sample2, data_sample3] + batch_aug_data_samples = [aug_data_samples] * 3 + model = EncoderDecoderRecognizerTTAModel(module=DummyModel()) + preds = model.merge_preds(batch_aug_data_samples) + for pred in preds: + self.assertEqual(pred.pred_text.text, 'cdefg') diff --git a/tools/test.py b/tools/test.py index 3699e99a9..04b5d2c61 100755 --- a/tools/test.py +++ b/tools/test.py @@ -45,6 +45,8 @@ def parse_args(): choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='Job launcher') + parser.add_argument( + '--tta', action='store_true', help='Test time augmentation') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: @@ -107,6 +109,13 @@ def main(): if args.show or args.show_dir: cfg = trigger_visualization_hook(cfg, args) + cfg.load_from = args.checkpoint + + if args.tta: + cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline + cfg.tta_model.module = cfg.model + cfg.model = cfg.tta_model + # save predictions if args.save_preds: dump_metric = dict(