From 504b1376eaa7a06cfb297ed356c95fda25e36bf4 Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Wed, 26 Oct 2022 22:17:03 +0800
Subject: [PATCH 1/7] add singan dataset

---
 mmedit/datasets/__init__.py                |   2 +
 mmedit/datasets/singan_dataset.py          | 139 ++++++++++++++++++
 mmedit/datasets/transforms/formatting.py   |  44 +++++-
 tests/test_datasets/test_singan_dataset.py |  31 ++++
 .../test_transforms/test_formatting.py     |  26 +++-
 5 files changed, 233 insertions(+), 9 deletions(-)
 create mode 100644 mmedit/datasets/singan_dataset.py
 create mode 100644 tests/test_datasets/test_singan_dataset.py

diff --git a/mmedit/datasets/__init__.py b/mmedit/datasets/__init__.py
index 0744816ca6..dd8a74e5da 100644
--- a/mmedit/datasets/__init__.py
+++ b/mmedit/datasets/__init__.py
@@ -7,6 +7,7 @@
 from .grow_scale_image_dataset import GrowScaleImgDataset
 from .imagenet_dataset import ImageNet
 from .paired_image_dataset import PairedImageDataset
+from .singan_dataset import SinGANDataset
 from .unpaired_image_dataset import UnpairedImageDataset
 
 __all__ = [
@@ -19,4 +20,5 @@
     'ImageNet',
     'CIFAR10',
     'GrowScaleImgDataset',
+    'SinGANDataset',
 ]
diff --git a/mmedit/datasets/singan_dataset.py b/mmedit/datasets/singan_dataset.py
new file mode 100644
index 0000000000..e03d4567b8
--- /dev/null
+++ b/mmedit/datasets/singan_dataset.py
@@ -0,0 +1,139 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+
+import mmcv
+import numpy as np
+from mmengine.dataset import BaseDataset
+
+from mmedit.registry import DATASETS
+
+
+def create_real_pyramid(real, min_size, max_size, scale_factor_init):
+    """Create image pyramid.
+
+    This function is modified from the official implementation:
+    https://github.com/tamarott/SinGAN/blob/master/SinGAN/functions.py#L221
+
+    In this implementation, we adopt the rescaling function from MMCV.
+    Args:
+        real (np.ndarray): The real image array.
+        min_size (int): The minimum size for the image pyramid.
+        max_size (int): The maximum size for the image pyramid.
+        scale_factor_init (float): The initial scale factor.
+    """
+
+    num_scales = int(
+        np.ceil(
+            np.log(np.power(min_size / min(real.shape[0], real.shape[1]), 1)) /
+            np.log(scale_factor_init))) + 1
+
+    scale2stop = int(
+        np.ceil(
+            np.log(
+                min([max_size, max([real.shape[0], real.shape[1]])]) /
+                max([real.shape[0], real.shape[1]])) /
+            np.log(scale_factor_init)))
+
+    stop_scale = num_scales - scale2stop
+
+    scale1 = min(max_size / max([real.shape[0], real.shape[1]]), 1)
+    real_max = mmcv.imrescale(real, scale1)
+    scale_factor = np.power(
+        min_size / (min(real_max.shape[0], real_max.shape[1])),
+        1 / (stop_scale))
+
+    scale2stop = int(
+        np.ceil(
+            np.log(
+                min([max_size, max([real.shape[0], real.shape[1]])]) /
+                max([real.shape[0], real.shape[1]])) /
+            np.log(scale_factor_init)))
+    stop_scale = num_scales - scale2stop
+
+    reals = []
+    for i in range(stop_scale + 1):
+        scale = np.power(scale_factor, stop_scale - i)
+        curr_real = mmcv.imrescale(real, scale)
+        reals.append(curr_real)
+
+    return reals, scale_factor, stop_scale
+
+
+@DATASETS.register_module()
+class SinGANDataset(BaseDataset):
+    """SinGAN Dataset.
+
+    In this dataset, we create an image pyramid and save it in the cache.
+
+    Args:
+        data_root (str): Path to the single image file.
+        min_size (int): Min size of the image pyramid. Here, the number will be
+            set to the ``min(H, W)``.
+        max_size (int): Max size of the image pyramid. Here, the number will be
+            set to the ``max(H, W)``.
+        scale_factor_init (float): Rescale factor. Note that the actual factor
+            we use may be a little bit different from this value.
+        num_samples (int, optional): The number of samples (length) in this
+            dataset. Defaults to -1.
+    """
+
+    def __init__(self,
+                 data_root,
+                 min_size,
+                 max_size,
+                 scale_factor_init,
+                 pipeline,
+                 num_samples=-1):
+        self.min_size = min_size
+        self.max_size = max_size
+        self.scale_factor_init = scale_factor_init
+        self.num_samples = num_samples
+        super().__init__(data_root=data_root, pipeline=pipeline)
+
+    def full_init(self):
+        """Skip the full init process for SinGANDataset."""
+
+        self.load_data_list(self.min_size, self.max_size,
+                            self.scale_factor_init)
+
+    def load_data_list(self, min_size, max_size, scale_factor_init):
+        """Load annotations for SinGAN Dataset.
+
+        Args:
+            min_size (int): The minimum size for the image pyramid.
+            max_size (int): The maximum size for the image pyramid.
+            scale_factor_init (float): The initial scale factor.
+        """
+        real = mmcv.imread(self.data_root)
+        self.reals, self.scale_factor, self.stop_scale = create_real_pyramid(
+            real, min_size, max_size, scale_factor_init)
+
+        self.data_dict = {}
+
+        for i, real in enumerate(self.reals):
+            self.data_dict[f'real_scale{i}'] = real
+
+        self.data_dict['input_sample'] = np.zeros_like(
+            self.data_dict['real_scale0']).astype(np.float32)
+
+    def __getitem__(self, index):
+        """Get :attr:`self.data_dict`. For SinGAN, we use a single image at
+        different resolutions to train the model.
+
+        Args:
+            index (int): This will be ignored in :class:`SinGANDataset`.
+
+        Returns:
+            dict: Dict containing the input image at different resolutions,
+                processed by ``self.pipeline``.
+        """
+        return self.pipeline(deepcopy(self.data_dict))
+
+    def __len__(self):
+        """Get the length of filtered dataset and automatically call
+        ``full_init`` if the dataset has not been fully initialized.
+
+        Returns:
+            int: The length of filtered dataset.
+        """
+        return int(1e6) if self.num_samples < 0 else self.num_samples
diff --git a/mmedit/datasets/transforms/formatting.py b/mmedit/datasets/transforms/formatting.py
index f271c31bca..d375b56c47 100644
--- a/mmedit/datasets/transforms/formatting.py
+++ b/mmedit/datasets/transforms/formatting.py
@@ -76,6 +76,27 @@ def images_to_tensor(value):
     return tensor
 
 
+def can_convert_to_image(value):
+    """Judge whether the input value can be converted to an image tensor via
+    the :func:`images_to_tensor` function.
+
+    Args:
+        value (any): The input value.
+
+    Returns:
+        bool: If true, the input value can be converted to an image with
+            :func:`images_to_tensor`, and vice versa.
+    """
+    if isinstance(value, (List, Tuple)):
+        return all([can_convert_to_image(v) for v in value])
+    elif isinstance(value, np.ndarray):
+        return True
+    elif isinstance(value, torch.Tensor):
+        return True
+    else:
+        return False
+
+
 @TRANSFORMS.register_module()
 class PackEditInputs(BaseTransform):
     """Pack the inputs data for SR, VFI, matting and inpainting.
@@ -83,11 +104,17 @@ class PackEditInputs(BaseTransform):
     Keys for images include ``img``, ``gt``, ``ref``, ``mask``, ``gt_heatmap``,
         ``trimap``, ``gt_alpha``, ``gt_fg``, ``gt_bg``. All of them will be
         packed into data field of EditDataSample.
+    pack_all (bool): Whether to pack all variables in `results` into the
+        `inputs` dict. This is useful when the keys of the input dict are
+        not fixed. Please be careful when using this option.
+        Defaults to False.
 
     Others will be packed into metainfo field of EditDataSample.
""" - def __init__(self, keys: Tuple[List[str], str, None] = None): + def __init__(self, + keys: Tuple[List[str], str, None] = None, + pack_all: bool = False): if keys is not None: if isinstance(keys, list): self.keys = keys @@ -95,6 +122,7 @@ def __init__(self, keys: Tuple[List[str], str, None] = None): self.keys = [keys] else: self.keys = None + self.pack_all = pack_all def transform(self, results: dict) -> dict: """Method to pack the input data. @@ -113,14 +141,14 @@ def transform(self, results: dict) -> dict: packed_results = dict() data_sample = EditDataSample() - if self.keys is not None: + pack_keys = [k for k in results.keys()] if self.pack_all else self.keys + if pack_keys is not None: packed_results['inputs'] = dict() - for key in self.keys: - img = results.pop(key) - if len(img.shape) < 3: - img = np.expand_dims(img, -1) - img = np.ascontiguousarray(img.transpose(2, 0, 1)) - packed_results['inputs'][key] = to_tensor(img) + for key in pack_keys: + val = results[key] + if can_convert_to_image(val): + packed_results['inputs'][key] = images_to_tensor(val) + results.pop(key) elif 'img' in results: img = results.pop('img') diff --git a/tests/test_datasets/test_singan_dataset.py b/tests/test_datasets/test_singan_dataset.py new file mode 100644 index 0000000000..a872d4242f --- /dev/null +++ b/tests/test_datasets/test_singan_dataset.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp + +from mmedit.datasets import SinGANDataset +from mmedit.utils import register_all_modules + +register_all_modules() + + +class TestSinGANDataset(object): + + @classmethod + def setup_class(cls): + cls.imgs_root = osp.join( + osp.dirname(osp.dirname(__file__)), 'data/image/gt/baboon.png') + cls.min_size = 25 + cls.max_size = 250 + cls.scale_factor_init = 0.75 + cls.pipeline = [dict(type='PackEditInputs', pack_all=True)] + + def test_singan_dataset(self): + dataset = SinGANDataset( + self.imgs_root, + min_size=self.min_size, + max_size=self.max_size, + scale_factor_init=self.scale_factor_init, + pipeline=self.pipeline) + assert len(dataset) == 1000000 + + data_dict = dataset[0]['inputs'] + assert all([f'real_scale{i}' in data_dict for i in range(10)]) diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py index 55b633e2e8..c9ac37ddf1 100644 --- a/tests/test_datasets/test_transforms/test_formatting.py +++ b/tests/test_datasets/test_transforms/test_formatting.py @@ -4,7 +4,8 @@ from mmcv.transforms import to_tensor from mmedit.datasets.transforms import PackEditInputs, ToTensor -from mmedit.datasets.transforms.formatting import images_to_tensor +from mmedit.datasets.transforms.formatting import (can_convert_to_image, + images_to_tensor) from mmedit.structures.edit_data_sample import EditDataSample @@ -117,6 +118,18 @@ def test_pack_edit_inputs(): assert data_sample.metainfo['img_shape'] == (64, 64) assert data_sample.metainfo['a'] == 'b' + # test pack_all + pack_edit_inputs = PackEditInputs(pack_all=True) + results = ori_results.copy() + packed_results = pack_edit_inputs(results) + print(packed_results['inputs'].keys()) + + target_keys = [ + 'img', 'gt', 'img_lq', 'ref', 'ref_lq', 'mask', 'gt_heatmap', + 'gt_unsharp', 'merged', 'trimap', 'alpha', 'fg', 'bg' + ] + assert all([k in target_keys for k in packed_results['inputs']]) + def test_to_tensor(): @@ -135,3 +148,14 @@ def test_to_tensor(): assert set(keys).issubset(results.keys()) for _, v in results.items(): assert isinstance(v, torch.Tensor) + + 
+def test_can_convert_to_image():
+    values = [
+        np.random.rand(64, 64, 3),
+        [np.random.rand(64, 61, 3),
+         np.random.rand(64, 61, 3)], (64, 64), 'b'
+    ]
+    targets = [True, True, False, False]
+    for val, tar in zip(values, targets):
+        assert can_convert_to_image(val) == tar

From 237b9a8c6ec29b9417c38c9e60ebab2a72015a72 Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Wed, 26 Oct 2022 22:18:01 +0800
Subject: [PATCH 2/7] adopt singan config with PackEditInputs

---
 configs/singan/singan_fish.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/configs/singan/singan_fish.py b/configs/singan/singan_fish.py
index df2f946408..1ffbf734eb 100644
--- a/configs/singan/singan_fish.py
+++ b/configs/singan/singan_fish.py
@@ -41,7 +41,7 @@
 dataset_type = 'SinGANDataset'
 data_root = './data/singan/fish-crop.jpg'
 
-pipeline = [dict(type='PackEditInputs')]
+pipeline = [dict(type='PackEditInputs', pack_all=True)]
 dataset = dict(
     type=dataset_type,
     data_root=data_root,

From 7985126143a7746b8038d9ba8e5ce6da9812f820 Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Wed, 26 Oct 2022 22:25:24 +0800
Subject: [PATCH 3/7] revise forward logic of SinGANMSGeneratorPE

---
 mmedit/models/editors/mspie/pe_singan_generator.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/mmedit/models/editors/mspie/pe_singan_generator.py b/mmedit/models/editors/mspie/pe_singan_generator.py
index f40c74e6f0..d4cd33bdf2 100644
--- a/mmedit/models/editors/mspie/pe_singan_generator.py
+++ b/mmedit/models/editors/mspie/pe_singan_generator.py
@@ -167,9 +167,12 @@ def forward(self,
         noise_list = []
 
         if input_sample is None:
+            h, w = fixed_noises[0].shape[-2:]
+            if self.noise_with_pad:
+                h -= 2 * self.pad_head
+                w -= 2 * self.pad_head
             input_sample = torch.zeros(
-                (num_batches, 3, fixed_noises[0].shape[-2],
-                 fixed_noises[0].shape[-1])).to(fixed_noises[0])
+                (num_batches, 3, h, w)).to(fixed_noises[0])
 
         g_res = input_sample

From fda531500ac130cb9e5904db62948ba3555567ad Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Wed, 26 Oct 2022 22:26:26 +0800
Subject: [PATCH 4/7] fix ema-related logic of SinGAN

---
 mmedit/models/editors/singan/singan.py | 44 ++++++++++++++++----------
 1 file changed, 27 insertions(+), 17 deletions(-)

diff --git a/mmedit/models/editors/singan/singan.py b/mmedit/models/editors/singan/singan.py
index 391f55c460..3c640c5bc0 100644
--- a/mmedit/models/editors/singan/singan.py
+++ b/mmedit/models/editors/singan/singan.py
@@ -16,7 +16,7 @@
 from mmedit.models.utils import get_module_device
 from mmedit.registry import MODELS
 from mmedit.structures import EditDataSample, PixelData
-from mmedit.utils import SampleList
+from mmedit.utils import ForwardInputs, SampleList
 from ...base_models import BaseGAN
 from ...utils import set_requires_grad
 
@@ -171,27 +171,27 @@ def construct_fixed_noises(self):
             self.fixed_noises.append(noise)
 
     def forward(self,
-                batch_inputs: dict,
+                inputs: ForwardInputs,
                 data_samples: Optional[list] = None,
                 mode=None) -> List[EditDataSample]:
-        """Forward function for SinGAN. For SinGAN, `batch_inputs` should be a
-        dict contains 'num_batches', 'mode' and other input arguments for the
+        """Forward function for SinGAN. For SinGAN, `inputs` should be a dict
+        containing 'num_batches', 'mode' and other input arguments for the
         generator.
 
         Args:
-            batch_inputs (dict): Dict containing the necessary information
+            inputs (dict): Dict containing the necessary information
                 (e.g., noise, num_batches, mode) to generate image.
             data_samples (Optional[list]): Data samples collated by
                 :attr:`data_preprocessor`. Defaults to None.
             mode (Optional[str]): `mode` is not used in
                 :class:`BaseConditionalGAN`. Defaults to None.
         """
-        sample_model = self._get_valid_model(batch_inputs)
+        sample_model = self._get_valid_model(inputs)
         # handle batch_inputs
-        assert isinstance(batch_inputs, dict), (
-            'SinGAN only support dict type batch_inputs in forward function.')
-        gen_kwargs = deepcopy(batch_inputs)
+        assert isinstance(inputs, dict), (
+            'SinGAN only support dict type inputs in forward function.')
+        gen_kwargs = deepcopy(inputs)
         num_batches = gen_kwargs.pop('num_batches', 1)
         assert num_batches == 1, (
             'SinGAN only support \'num_batches\' as 1, but receive '
             f'{num_batches}.')
@@ -235,14 +235,24 @@ def forward(self,
             gen_sample = EditDataSample()
             if data_samples:
                 gen_sample.update(data_samples[idx])
-            if isinstance(outputs, dict):
-                gen_sample.ema = EditDataSample(
-                    fake_img=PixelData(data=outputs['ema'][idx]),
-                    sample_model='ema')
-                gen_sample.orig = EditDataSample(
-                    fake_img=PixelData(data=outputs['orig'][idx]),
-                    sample_model='orig')
-                gen_sample.sample_model = 'ema/orig'
+            if sample_model == 'ema/orig':
+                for model_ in ['ema', 'orig']:
+                    model_sample_ = EditDataSample()
+                    fake_img = PixelData(data=outputs[model_]['fake_img'][idx])
+                    prev_res_list = [
+                        r[idx] for r in outputs[model_]['prev_res_list']
+                    ]
+                    model_sample_.fake_img = fake_img
+                    model_sample_.prev_res_list = prev_res_list
+                    model_sample_.sample_model = sample_model
+
+                    gen_sample.set_field(model_, model_sample_)
+            elif isinstance(outputs, dict):
+                gen_sample.fake_img = PixelData(data=outputs['fake_img'][idx])
+                gen_sample.prev_res_list = [
+                    r[idx] for r in outputs['prev_res_list']
+                ]
+                gen_sample.sample_model = sample_model
             else:
                 gen_sample.fake_img = PixelData(data=outputs[idx])
                 gen_sample.sample_model = sample_model

From 488dae15f7505ed7ac85b6fba2ae63ac1b6cde04 Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Wed, 26 Oct 2022 22:26:50 +0800
Subject: [PATCH 5/7] add singan demo

---
 demo/singan_demo.py | 113 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 113 insertions(+)
 create mode 100644 demo/singan_demo.py

diff --git a/demo/singan_demo.py b/demo/singan_demo.py
new file mode 100644
index 0000000000..9633adfe1a
--- /dev/null
+++ b/demo/singan_demo.py
@@ -0,0 +1,113 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os
+import sys
+
+import mmcv
+import torch
+from mmengine import Config, print_log
+from mmengine.logging import MMLogger
+from mmengine.runner import load_checkpoint, set_random_seed
+
+# yapf: disable
+sys.path.append(os.path.abspath(os.path.join(__file__, '../..')))  # isort:skip  # noqa
+
+from mmedit.engine import *  # isort:skip  # noqa: F401,F403,E402
+from mmedit.datasets import *  # isort:skip  # noqa: F401,F403,E402
+from mmedit.models import *  # isort:skip  # noqa: F401,F403,E402
+
+from mmedit.registry import MODELS  # isort:skip  # noqa
+
+# yapf: enable
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Evaluate a GAN model')
+    parser.add_argument('config', help='evaluation config file path')
+    parser.add_argument('checkpoint', help='checkpoint file')
+    parser.add_argument('--seed', type=int, default=2021, help='random seed')
+    parser.add_argument(
+        '--deterministic',
+        action='store_true',
+        help='whether to set deterministic options for CUDNN backend.')
+    parser.add_argument(
+        '--samples-path',
+        type=str,
+        default='./',
+        help='path to store images. If not given, remove it after evaluation\
+        finished')
+    parser.add_argument(
+        '--save-prev-res',
+        action='store_true',
+        help='whether to store the results from previous stages')
+    parser.add_argument(
+        '--num-samples',
+        type=int,
+        default=10,
+        help='the number of synthesized samples')
+    args = parser.parse_args()
+    return args
+
+
+def _tensor2img(img):
+    img = img.permute(1, 2, 0)
+    img = ((img + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)
+
+    return img.cpu().numpy()
+
+
+@torch.no_grad()
+def main():
+    MMLogger.get_instance('mmedit')
+
+    args = parse_args()
+    cfg = Config.fromfile(args.config)
+    # set cudnn_benchmark
+    if cfg.get('cudnn_benchmark', False):
+        torch.backends.cudnn.benchmark = True
+
+    # set random seeds
+    if args.seed is not None:
+        set_random_seed(args.seed, deterministic=args.deterministic)
+
+    # set scope manually
+    cfg.model['_scope_'] = 'mmedit'
+    # build the model and load checkpoint
+    model = MODELS.build(cfg.model)
+
+    model.eval()
+
+    # load ckpt
+    print_log(f'Loading ckpt from {args.checkpoint}')
+    _ = load_checkpoint(model, args.checkpoint, map_location='cpu')
+
+    # add dp wrapper
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    for sample_iter in range(args.num_samples):
+        outputs = model.test_step(
+            dict(inputs=dict(num_batches=1, get_prev_res=args.save_prev_res)))
+
+        # store results from previous stages
+        if args.save_prev_res:
+            fake_img = outputs[0].fake_img.data
+            prev_res_list = outputs[0].prev_res_list
+            prev_res_list.append(fake_img)
+            for i, img in enumerate(prev_res_list):
+                img = _tensor2img(img)
+                mmcv.imwrite(
+                    img,
+                    os.path.join(args.samples_path, f'stage{i}',
+                                 f'rand_sample_{sample_iter}.png'))
+        # just store the final result
+        else:
+            img = _tensor2img(outputs[0].fake_img.data)
+            mmcv.imwrite(
+                img,
+                os.path.join(args.samples_path,
+                             f'rand_sample_{sample_iter}.png'))
+
+
+if __name__ == '__main__':
+    main()

From 28682fb6be035fcd8e28cb6e1b1b331d8a8244ad Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Thu, 27 Oct 2022 21:00:05 +0800
Subject: [PATCH 6/7] add unit test for pe-singan

---
 .../test_mspie/test_pe_singan_generator.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/tests/test_models/test_editors/test_mspie/test_pe_singan_generator.py b/tests/test_models/test_editors/test_mspie/test_pe_singan_generator.py
index 6e14b7b5af..144baaf8fe 100644
--- a/tests/test_models/test_editors/test_mspie/test_pe_singan_generator.py
+++ b/tests/test_models/test_editors/test_mspie/test_pe_singan_generator.py
@@ -83,3 +83,13 @@ def test_singan_gen_pe(self):
         res = gen(self.input_sample, self.fixed_noises, self.noise_weights,
                   'rand', 2)
         assert res.shape == (1, 3, 12, 12)
+
+        gen = SinGANMSGeneratorPE(
+            interp_pad=True, noise_with_pad=True, **self.default_args)
+        res = gen(None, self.fixed_noises, self.noise_weights, 'rand', 2)
+        assert res.shape == (1, 3, 6, 6)
+
+        gen = SinGANMSGeneratorPE(
+            interp_pad=True, noise_with_pad=False, **self.default_args)
+        res = gen(None, self.fixed_noises, self.noise_weights, 'rand', 2)
+        assert res.shape == (1, 3, 12, 12)

From 3910edf15d5dd7d80ab45df76463664223361060 Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Thu, 27 Oct 2022 21:53:18 +0800
Subject: [PATCH 7/7] add unit test for singan-ema and revise singan

---
 mmedit/models/editors/singan/singan.py       | 21 +++++---
 .../test_editors/test_singan/test_singan.py  | 50 +++++++++++++++++++
 2 files changed, 63 insertions(+), 8 deletions(-)

diff --git a/mmedit/models/editors/singan/singan.py b/mmedit/models/editors/singan/singan.py
index 3c640c5bc0..53e39c0834 100644
--- a/mmedit/models/editors/singan/singan.py
+++ b/mmedit/models/editors/singan/singan.py
@@ -186,7 +186,6 @@ def forward(self,
             mode (Optional[str]): `mode` is not used in
                 :class:`BaseConditionalGAN`. Defaults to None.
         """
-        sample_model = self._get_valid_model(inputs)
         # handle batch_inputs
         assert isinstance(inputs, dict), (
@@ -196,6 +195,9 @@ def forward(self,
         num_batches = gen_kwargs.pop('num_batches', 1)
         assert num_batches == 1, (
             'SinGAN only support \'num_batches\' as 1, but receive '
             f'{num_batches}.')
+        sample_model = self._get_valid_model(inputs)
+        gen_kwargs.pop('sample_model', None)  # remove sample_model
+        mode = gen_kwargs.pop('mode', mode)
         mode = 'rand' if mode is None else mode
         curr_scale = gen_kwargs.pop('curr_scale', self.curr_stage)
@@ -238,15 +240,18 @@ def forward(self,
             if sample_model == 'ema/orig':
                 for model_ in ['ema', 'orig']:
                     model_sample_ = EditDataSample()
-                    fake_img = PixelData(data=outputs[model_]['fake_img'][idx])
-                    prev_res_list = [
-                        r[idx] for r in outputs[model_]['prev_res_list']
-                    ]
+                    output_ = outputs[model_]
+                    if isinstance(output_, dict):
+                        fake_img = PixelData(data=output_['fake_img'][idx])
+                        prev_res_list = [
+                            r[idx] for r in outputs[model_]['prev_res_list']
+                        ]
+                        model_sample_.prev_res_list = prev_res_list
+                    else:
+                        fake_img = PixelData(data=output_[idx])
                     model_sample_.fake_img = fake_img
-                    model_sample_.prev_res_list = prev_res_list
                     model_sample_.sample_model = sample_model
-
-                    gen_sample.set_field(model_, model_sample_)
+                    gen_sample.set_field(model_sample_, model_)
             elif isinstance(outputs, dict):
                 gen_sample.fake_img = PixelData(data=outputs['fake_img'][idx])
                 gen_sample.prev_res_list = [
diff --git a/tests/test_models/test_editors/test_singan/test_singan.py b/tests/test_models/test_editors/test_singan/test_singan.py
index a007d4f9e0..4b0bf73030 100644
--- a/tests/test_models/test_editors/test_singan/test_singan.py
+++ b/tests/test_models/test_editors/test_singan/test_singan.py
@@ -75,3 +75,53 @@ def test_singan_cpu(self):
             elif i in [4, 5]:
                 assert singan.curr_stage == 2
                 assert img.shape[-2:] == (32, 32)
+
+        outputs = singan.forward(
+            dict(num_batches=1, get_prev_res=True), None)
+        assert all([hasattr(out, 'prev_res_list') for out in outputs])
+
+        # test forward singan with ema
+        singan = SinGAN(
+            self.generator,
+            self.disc,
+            num_scales=3,
+            data_preprocessor=self.data_preprocessor,
+            noise_weight_init=self.noise_weight_init,
+            iters_per_scale=self.iters_per_scale,
+            lr_scheduler_args=self.lr_scheduler_args,
+            ema_confg=dict(type='ExponentialMovingAverage'))
+        optim_wrapper_dict_builder = SinGANOptimWrapperConstructor(
+            self.optim_wrapper_cfg)
+        optim_wrapper_dict = optim_wrapper_dict_builder(singan)
+
+        for i in range(6):
+            singan.train_step(self.data_batch, optim_wrapper_dict)
+            message_hub.update_info('iter', message_hub.get_info('iter') + 1)
+
+            outputs = singan.forward(
+                dict(num_batches=1, sample_model='ema/orig'), None)
+
+            img = torch.stack([out.orig.fake_img.data for out in outputs],
+                              dim=0)
+            img_ema = torch.stack([out.ema.fake_img.data for out in outputs],
+                                  dim=0)
+            if i in [0, 1]:
+                assert singan.curr_stage == 0
+                assert img.shape[-2:] == (25, 25)
+                assert img_ema.shape[-2:] == (25, 25)
+            elif i in [2, 3]:
+                assert singan.curr_stage == 1
+                assert img.shape[-2:] == (30, 30)
+                assert img_ema.shape[-2:] == (30, 30)
+            elif i in [4, 5]:
+                assert singan.curr_stage == 2
+                assert img.shape[-2:] == (32, 32)
+                assert img_ema.shape[-2:] == (32, 32)
+
+        outputs = singan.forward(
+            dict(
+                num_batches=1, sample_model='ema/orig', get_prev_res=True),
+            None)
+
+        assert all([hasattr(out.orig, 'prev_res_list') for out in outputs])
+        assert all([hasattr(out.ema, 'prev_res_list') for out in outputs])
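
For reference, the new SinGANDataset and the `pack_all` mode of PackEditInputs added in this series can be exercised together much like tests/test_datasets/test_singan_dataset.py does. The following is a minimal, illustrative sketch only: it assumes mmedit with these patches installed and an RGB image at the path taken from configs/singan/singan_fish.py (any other image path works); the pyramid sizes mirror the unit test.

    from mmedit.datasets import SinGANDataset
    from mmedit.utils import register_all_modules

    register_all_modules()

    # Build the single-image dataset; PackEditInputs(pack_all=True) packs
    # every pyramid level (and the zero-valued 'input_sample') into the
    # 'inputs' dict as tensors.
    dataset = SinGANDataset(
        data_root='./data/singan/fish-crop.jpg',  # placeholder: any RGB image
        min_size=25,
        max_size=250,
        scale_factor_init=0.75,
        pipeline=[dict(type='PackEditInputs', pack_all=True)])

    inputs = dataset[0]['inputs']
    print(sorted(inputs.keys()))  # ['input_sample', 'real_scale0', 'real_scale1', ...]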