[Enhancement] Fix PESinGAN-inter-pad setting + add SinGAN Dataset + add SinGAN demo #1363

Merged: 7 commits, Oct 30, 2022
2 changes: 1 addition & 1 deletion configs/singan/singan_fish.py
@@ -41,7 +41,7 @@
dataset_type = 'SinGANDataset'
data_root = './data/singan/fish-crop.jpg'

-pipeline = [dict(type='PackEditInputs')]
+pipeline = [dict(type='PackEditInputs', pack_all=True)]
dataset = dict(
type=dataset_type,
data_root=data_root,
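For context, the dataset block that this pipeline feeds would look roughly like the sketch below (the hunk above is truncated); the sizes and scale factor are illustrative assumptions, not values taken from this diff.

pipeline = [dict(type='PackEditInputs', pack_all=True)]
dataset = dict(
    type='SinGANDataset',
    data_root='./data/singan/fish-crop.jpg',
    min_size=25,             # assumed: shorter edge of the coarsest scale
    max_size=250,            # assumed: cap on the longer edge of the finest scale
    scale_factor_init=0.75,  # assumed initial rescale factor
    pipeline=pipeline,
    num_samples=-1)          # negative -> effectively infinite dataset length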
113 changes: 113 additions & 0 deletions demo/singan_demo.py
@@ -0,0 +1,113 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import sys

import mmcv
import torch
from mmengine import Config, print_log
from mmengine.logging import MMLogger
from mmengine.runner import load_checkpoint, set_random_seed

# yapf: disable
sys.path.append(os.path.abspath(os.path.join(__file__, '../..'))) # isort:skip # noqa

from mmedit.engine import * # isort:skip # noqa: F401,F403,E402
from mmedit.datasets import * # isort:skip # noqa: F401,F403,E402
from mmedit.models import * # isort:skip # noqa: F401,F403,E402

from mmedit.registry import MODELS # isort:skip # noqa

# yapf: enable


def parse_args():
parser = argparse.ArgumentParser(description='Evaluate a GAN model')
parser.add_argument('config', help='evaluation config file path')
parser.add_argument('checkpoint', help='checkpoint file')
parser.add_argument('--seed', type=int, default=2021, help='random seed')
parser.add_argument(
'--deterministic',
action='store_true',
help='whether to set deterministic options for CUDNN backend.')
parser.add_argument(
'--samples-path',
type=str,
default='./',
        help='path to store images. If not given, images are removed after '
             'evaluation finishes')
parser.add_argument(
'--save-prev-res',
action='store_true',
help='whether to store the results from previous stages')
parser.add_argument(
'--num-samples',
type=int,
default=10,
help='the number of synthesized samples')
args = parser.parse_args()
return args


def _tensor2img(img):
img = img.permute(1, 2, 0)
img = ((img + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)

return img.cpu().numpy()


@torch.no_grad()
def main():
MMLogger.get_instance('mmedit')

args = parse_args()
cfg = Config.fromfile(args.config)
# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True

# set random seeds
if args.seed is not None:
set_random_seed(args.seed, deterministic=args.deterministic)

# set scope manually
cfg.model['_scope_'] = 'mmedit'
# build the model and load checkpoint
model = MODELS.build(cfg.model)

model.eval()

# load ckpt
print_log(f'Loading ckpt from {args.checkpoint}')
_ = load_checkpoint(model, args.checkpoint, map_location='cpu')

# add dp wrapper
if torch.cuda.is_available():
model = model.cuda()

for sample_iter in range(args.num_samples):
outputs = model.test_step(
dict(inputs=dict(num_batches=1, get_prev_res=args.save_prev_res)))

# store results from previous stages
if args.save_prev_res:
fake_img = outputs[0].fake_img.data
prev_res_list = outputs[0].prev_res_list
prev_res_list.append(fake_img)
for i, img in enumerate(prev_res_list):
img = _tensor2img(img)
mmcv.imwrite(
img,
os.path.join(args.samples_path, f'stage{i}',
f'rand_sample_{sample_iter}.png'))
# just store the final result
else:
img = _tensor2img(outputs[0].fake_img.data)
mmcv.imwrite(
img,
os.path.join(args.samples_path,
f'rand_sample_{sample_iter}.png'))


if __name__ == '__main__':
main()
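A typical invocation of this demo (paths here are illustrative, not from this PR) is: python demo/singan_demo.py configs/singan/singan_fish.py work_dirs/singan_fish/latest.pth --samples-path work_dirs/singan_samples --num-samples 5. The script writes rand_sample_<i>.png files under the samples path, plus stage<j>/ subdirectories with intermediate results when --save-prev-res is set.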
2 changes: 2 additions & 0 deletions mmedit/datasets/__init__.py
@@ -7,6 +7,7 @@
from .grow_scale_image_dataset import GrowScaleImgDataset
from .imagenet_dataset import ImageNet
from .paired_image_dataset import PairedImageDataset
from .singan_dataset import SinGANDataset
from .unpaired_image_dataset import UnpairedImageDataset

__all__ = [
@@ -19,4 +20,5 @@
'ImageNet',
'CIFAR10',
'GrowScaleImgDataset',
'SinGANDataset',
]
139 changes: 139 additions & 0 deletions mmedit/datasets/singan_dataset.py
@@ -0,0 +1,139 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

import mmcv
import numpy as np
from mmengine.dataset import BaseDataset

from mmedit.registry import DATASETS


def create_real_pyramid(real, min_size, max_size, scale_factor_init):
"""Create image pyramid.

This function is modified from the official implementation:
    https://github.com/tamarott/SinGAN/blob/master/SinGAN/functions.py#L221

    In this implementation, we adopt the rescaling function from MMCV.

    Args:
        real (np.array): The real image array.
        min_size (int): The minimum size for the image pyramid.
        max_size (int): The maximum size for the image pyramid.
        scale_factor_init (float): The initial scale factor.

    Returns:
        Tuple[list, float, int]: The image pyramid (coarse to fine), the
            actual scale factor used, and the index of the coarsest scale
            (``stop_scale``).
    """

num_scales = int(
np.ceil(
np.log(np.power(min_size / min(real.shape[0], real.shape[1]), 1)) /
np.log(scale_factor_init))) + 1

scale2stop = int(
np.ceil(
np.log(
min([max_size, max([real.shape[0], real.shape[1]])]) /
max([real.shape[0], real.shape[1]])) /
np.log(scale_factor_init)))

stop_scale = num_scales - scale2stop

scale1 = min(max_size / max([real.shape[0], real.shape[1]]), 1)
real_max = mmcv.imrescale(real, scale1)
scale_factor = np.power(
min_size / (min(real_max.shape[0], real_max.shape[1])),
1 / (stop_scale))

scale2stop = int(
np.ceil(
np.log(
min([max_size, max([real.shape[0], real.shape[1]])]) /
max([real.shape[0], real.shape[1]])) /
np.log(scale_factor_init)))
stop_scale = num_scales - scale2stop

reals = []
for i in range(stop_scale + 1):
scale = np.power(scale_factor, stop_scale - i)
curr_real = mmcv.imrescale(real, scale)
reals.append(curr_real)

return reals, scale_factor, stop_scale


@DATASETS.register_module()
class SinGANDataset(BaseDataset):
"""SinGAN Dataset.

In this dataset, we create an image pyramid and save it in the cache.

Args:
        data_root (str): Path to the single training image.
        min_size (int): Min size of the image pyramid. The shorter edge
            (``min(H, W)``) of the coarsest image is set to this value.
        max_size (int): Max size of the image pyramid. The longer edge
            (``max(H, W)``) of the finest image will not exceed this value.
        scale_factor_init (float): Rescale factor. Note that the actual
            factor used may differ slightly from this value.
        pipeline (list[dict]): Data pipeline applied to the packed data dict.
        num_samples (int, optional): The number of samples (length) of this
            dataset. A negative value means a practically infinite length
            (``int(1e6)``). Defaults to -1.
"""

def __init__(self,
data_root,
min_size,
max_size,
scale_factor_init,
pipeline,
num_samples=-1):
self.min_size = min_size
self.max_size = max_size
self.scale_factor_init = scale_factor_init
self.num_samples = num_samples
super().__init__(data_root=data_root, pipeline=pipeline)

def full_init(self):
"""Skip the full init process for SinGANDataset."""

self.load_data_list(self.min_size, self.max_size,
self.scale_factor_init)

def load_data_list(self, min_size, max_size, scale_factor_init):
"""Load annatations for SinGAN Dataset.

Args:
min_size (int): The minimum size for the image pyramid.
max_size (int): The maximum size for the image pyramid.
scale_factor_init (float): The initial scale factor.
"""
real = mmcv.imread(self.data_root)
self.reals, self.scale_factor, self.stop_scale = create_real_pyramid(
real, min_size, max_size, scale_factor_init)

self.data_dict = {}

for i, real in enumerate(self.reals):
self.data_dict[f'real_scale{i}'] = real

self.data_dict['input_sample'] = np.zeros_like(
self.data_dict['real_scale0']).astype(np.float32)

def __getitem__(self, index):
"""Get `:attr:self.data_dict`. For SinGAN, we use single image with
different resolution to train the model.

Args:
idx (int): This will be ignored in `:class:SinGANDataset`.

Returns:
dict: Dict contains input image in different resolution.
``self.pipeline``.
"""
return self.pipeline(deepcopy(self.data_dict))

def __len__(self):
"""Get the length of filtered dataset and automatically call
``full_init`` if the dataset has not been fully init.

Returns:
int: The length of filtered dataset.
"""
return int(1e6) if self.num_samples < 0 else self.num_samples
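To make the new dataset concrete, here is a minimal sketch of building the pyramid and the dataset directly. It assumes mmcv and mmedit are installed, that PackEditInputs is importable from mmedit.datasets.transforms, and uses placeholder paths and sizes rather than values from this PR.

import numpy as np

from mmedit.datasets.singan_dataset import SinGANDataset, create_real_pyramid
from mmedit.datasets.transforms import PackEditInputs  # assumed import path

# Build the pyramid directly from a dummy image (sizes are placeholders).
dummy = np.random.randint(0, 255, (188, 250, 3), dtype=np.uint8)
reals, scale_factor, stop_scale = create_real_pyramid(
    dummy, min_size=25, max_size=250, scale_factor_init=0.75)
print(len(reals), round(scale_factor, 3), stop_scale)
print(reals[0].shape, reals[-1].shape)  # coarsest and finest scales

# The dataset packs every scale plus a zero 'input_sample' into one dict,
# which is why the pipeline needs PackEditInputs(pack_all=True).
dataset = SinGANDataset(
    data_root='./data/singan/fish-crop.jpg',  # placeholder path
    min_size=25,
    max_size=250,
    scale_factor_init=0.75,
    pipeline=[PackEditInputs(pack_all=True)])
sample = dataset[0]
print(sorted(sample['inputs'].keys()))  # input_sample, real_scale0, real_scale1, ...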
44 changes: 36 additions & 8 deletions mmedit/datasets/transforms/formatting.py
@@ -76,25 +76,53 @@ def images_to_tensor(value):
return tensor


def can_convert_to_image(value):
"""Judge whether the input value can be converted to image tensor via
:func:`images_to_tensor` function.

Args:
value (any): The input value.

Returns:
bool: If true, the input value can convert to image with
:func:`images_to_tensor`, and vice versa.
"""
if isinstance(value, (List, Tuple)):
return all([can_convert_to_image(v) for v in value])
elif isinstance(value, np.ndarray):
return True
elif isinstance(value, torch.Tensor):
return True
else:
return False


@TRANSFORMS.register_module()
class PackEditInputs(BaseTransform):
"""Pack the inputs data for SR, VFI, matting and inpainting.

Keys for images include ``img``, ``gt``, ``ref``, ``mask``, ``gt_heatmap``,
``trimap``, ``gt_alpha``, ``gt_fg``, ``gt_bg``. All of them will be
packed into data field of EditDataSample.
pack_all (bool): Whether pack all variables in `results` to `inputs` dict.
This is useful when keys of the input dict is not fixed.
Please be careful when using this function, because we do not
Defaults to False.

Others will be packed into metainfo field of EditDataSample.
"""

-    def __init__(self, keys: Tuple[List[str], str, None] = None):
+    def __init__(self,
+                 keys: Tuple[List[str], str, None] = None,
+                 pack_all: bool = False):
if keys is not None:
if isinstance(keys, list):
self.keys = keys
else:
self.keys = [keys]
else:
self.keys = None
self.pack_all = pack_all

def transform(self, results: dict) -> dict:
"""Method to pack the input data.
@@ -113,14 +141,14 @@ def transform(self, results: dict) -> dict:
packed_results = dict()
data_sample = EditDataSample()

-        if self.keys is not None:
+        pack_keys = [k for k in results.keys()] if self.pack_all else self.keys
+        if pack_keys is not None:
             packed_results['inputs'] = dict()
-            for key in self.keys:
-                img = results.pop(key)
-                if len(img.shape) < 3:
-                    img = np.expand_dims(img, -1)
-                img = np.ascontiguousarray(img.transpose(2, 0, 1))
-                packed_results['inputs'][key] = to_tensor(img)
+            for key in pack_keys:
+                val = results[key]
+                if can_convert_to_image(val):
+                    packed_results['inputs'][key] = images_to_tensor(val)
+                    results.pop(key)

elif 'img' in results:
img = results.pop('img')
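A small sketch of what the new pack_all path does, assuming PackEditInputs is importable from mmedit.datasets.transforms: every value in results that can_convert_to_image accepts is converted with images_to_tensor and moved into the inputs dict, regardless of its key.

import numpy as np

from mmedit.datasets.transforms import PackEditInputs  # assumed import path

results = {
    'real_scale0': np.zeros((25, 34, 3), dtype=np.float32),
    'real_scale1': np.zeros((32, 42, 3), dtype=np.float32),
    'input_sample': np.zeros((25, 34, 3), dtype=np.float32),
}

packed = PackEditInputs(pack_all=True)(results)
# With pack_all=True the keys are not restricted to the fixed image-key list;
# every array-valued entry is tensorized (HWC -> CHW) and moved to 'inputs'.
print(sorted(packed['inputs'].keys()))
# ['input_sample', 'real_scale0', 'real_scale1']
print(packed['inputs']['real_scale0'].shape)  # torch.Size([3, 25, 34])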
7 changes: 5 additions & 2 deletions mmedit/models/editors/mspie/pe_singan_generator.py
@@ -167,9 +167,12 @@ def forward(self,
noise_list = []

if input_sample is None:
+            h, w = fixed_noises[0].shape[-2:]
+            if self.noise_with_pad:
+                h -= 2 * self.pad_head
+                w -= 2 * self.pad_head
             input_sample = torch.zeros(
-                (num_batches, 3, fixed_noises[0].shape[-2],
-                 fixed_noises[0].shape[-1])).to(fixed_noises[0])
+                (num_batches, 3, h, w)).to(fixed_noises[0])

g_res = input_sample

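The fix in this hunk only changes the shape of the default zero input at the coarsest scale: when noise_with_pad is enabled, the stored fixed noise is already padded, so the old code created a zero tensor that was 2 * pad_head pixels too large in each spatial dimension. A minimal sketch of the corrected shape bookkeeping, with illustrative sizes (not values from this PR):

import torch

pad_head = 5           # illustrative per-side padding on the fixed noise maps
img_h, img_w = 25, 32  # illustrative spatial size of the coarsest real image

# With noise_with_pad=True, the fixed noise at scale 0 is stored pre-padded,
# so its spatial size is the image size plus 2 * pad_head per dimension.
fixed_noise = torch.randn(1, 3, img_h + 2 * pad_head, img_w + 2 * pad_head)

# The fix shrinks the zero input back to the un-padded image size instead of
# reusing the padded noise size directly.
h, w = fixed_noise.shape[-2:]
h -= 2 * pad_head
w -= 2 * pad_head
input_sample = torch.zeros((1, 3, h, w)).to(fixed_noise)
assert input_sample.shape[-2:] == (img_h, img_w)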