diff --git a/configs/_base_/datasets/paired_imgs_256x256_crop.py b/configs/_base_/datasets/paired_imgs_256x256_crop.py index ec4dc4f4e0..6770420cf5 100644 --- a/configs/_base_/datasets/paired_imgs_256x256_crop.py +++ b/configs/_base_/datasets/paired_imgs_256x256_crop.py @@ -69,7 +69,7 @@ # `batch_size` and `data_root` need to be set. train_dataloader = dict( - batch_size=1, + batch_size=4, num_workers=4, persistent_workers=True, sampler=dict(type='InfiniteSampler', shuffle=True), @@ -79,7 +79,7 @@ pipeline=train_pipeline)) val_dataloader = dict( - batch_size=1, + batch_size=4, num_workers=4, dataset=dict( type=dataset_type, @@ -89,7 +89,7 @@ persistent_workers=True) test_dataloader = dict( - batch_size=1, + batch_size=4, num_workers=4, dataset=dict( type=dataset_type, diff --git a/configs/_base_/datasets/unpaired_imgs_256x256.py b/configs/_base_/datasets/unpaired_imgs_256x256.py index 42be934f3e..1b55ccc51d 100644 --- a/configs/_base_/datasets/unpaired_imgs_256x256.py +++ b/configs/_base_/datasets/unpaired_imgs_256x256.py @@ -64,7 +64,7 @@ # `batch_size` and `data_root` need to be set. train_dataloader = dict( - batch_size=1, + batch_size=4, num_workers=4, persistent_workers=True, sampler=dict(type='InfiniteSampler', shuffle=True), @@ -74,7 +74,7 @@ pipeline=train_pipeline)) val_dataloader = dict( - batch_size=1, + batch_size=4, num_workers=4, dataset=dict( type=dataset_type, @@ -85,7 +85,7 @@ persistent_workers=True) test_dataloader = dict( - batch_size=1, + batch_size=4, num_workers=4, dataset=dict( type=dataset_type, diff --git a/configs/_base_/gen_default_runtime.py b/configs/_base_/gen_default_runtime.py index a13f62f701..64a510e42e 100644 --- a/configs/_base_/gen_default_runtime.py +++ b/configs/_base_/gen_default_runtime.py @@ -19,8 +19,9 @@ type='CheckpointHook', interval=10000, by_epoch=False, - less_keys=['FID-Full-50k/fid'], - greater_keys=['IS-50k/is'], + max_keep_ckpts=20, + less_keys=['FID-Full-50k/fid', 'swd/avg'], + greater_keys=['IS-50k/is', 'ms-ssim/avg'], save_optimizer=True)) # config for environment diff --git a/configs/_base_/matting_default_runtime.py b/configs/_base_/matting_default_runtime.py index 9d9f09cd6c..e544dd8677 100644 --- a/configs/_base_/matting_default_runtime.py +++ b/configs/_base_/matting_default_runtime.py @@ -26,7 +26,7 @@ fn_key='trimap_path', img_keys=['pred_alpha', 'trimap', 'gt_merged', 'gt_alpha'], bgr2rgb=True) -custom_hooks = [dict(type='BasicVisualizationHook', interval=1)] +custom_hooks = [dict(type='BasicVisualizationHook', interval=2000)] log_level = 'INFO' log_processor = dict(type='LogProcessor', by_epoch=False) diff --git a/configs/ggan/ggan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py b/configs/ggan/ggan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py index a3c56f1355..a42716a2d9 100644 --- a/configs/ggan/ggan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py +++ b/configs/ggan/ggan_dcgan-archi_lr1e-3-1xb128-12Mimgs_celeba-cropped-64x64.py @@ -32,9 +32,7 @@ ] default_hooks = dict( checkpoint=dict( - max_keep_ckpts=20, - save_best=['FID-Full-50k/fid', 'swd/avg', 'ms-ssim/avg'], - rule=['less', 'less', 'greater'])) + max_keep_ckpts=20, save_best='FID-Full-50k/fid', rule='less')) # METRICS metrics = [ diff --git a/configs/ggan/ggan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py b/configs/ggan/ggan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py index 3d2052ec01..4173260d8b 100644 --- a/configs/ggan/ggan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py +++ b/configs/ggan/ggan_dcgan-archi_lr1e-4-1xb64-10Mimgs_celeba-cropped-128x128.py @@ -22,11 +22,11 @@ discriminator=dict( optimizer=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99)))) +train_cfg = dict(max_iters=160000) + default_hooks = dict( checkpoint=dict( - max_keep_ckpts=20, - save_best=['FID-Full-50k/fid', 'swd/avg', 'ms-ssim/avg'], - rule=['less', 'less', 'greater'])) + max_keep_ckpts=20, save_best='FID-Full-50k/fid', rule='less')) # VIS_HOOK custom_hooks = [ @@ -37,8 +37,6 @@ vis_kwargs_list=dict(type='GAN', name='fake_img')) ] -train_cfg = dict(max_iters=160000) - # METRICS metrics = [ dict( @@ -58,14 +56,5 @@ image_shape=(3, 128, 128)) ] -val_metrics = [ - dict( - type='FrechetInceptionDistance', - prefix='FID-Full-50k', - fake_nums=50000, - inception_style='StyleGAN', - sample_model='orig'), -] - -val_evaluator = dict(metrics=val_metrics) +val_evaluator = dict(metrics=metrics) test_evaluator = dict(metrics=metrics)