Skip to content

Commit

Permalink
[Config] Update metric config in ggan (#1386)
Browse files Browse the repository at this point in the history
* update metric config in ggan

* update gen_default_runtime
  • Loading branch information
LeoXing1996 committed Oct 31, 2022
1 parent 9c0768f commit 71961a7
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 27 deletions.
6 changes: 3 additions & 3 deletions configs/_base_/datasets/paired_imgs_256x256_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@

# `batch_size` and `data_root` need to be set.
train_dataloader = dict(
batch_size=1,
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
Expand All @@ -79,7 +79,7 @@
pipeline=train_pipeline))

val_dataloader = dict(
batch_size=1,
batch_size=4,
num_workers=4,
dataset=dict(
type=dataset_type,
Expand All @@ -89,7 +89,7 @@
persistent_workers=True)

test_dataloader = dict(
batch_size=1,
batch_size=4,
num_workers=4,
dataset=dict(
type=dataset_type,
Expand Down
6 changes: 3 additions & 3 deletions configs/_base_/datasets/unpaired_imgs_256x256.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

# `batch_size` and `data_root` need to be set.
train_dataloader = dict(
batch_size=1,
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
Expand All @@ -74,7 +74,7 @@
pipeline=train_pipeline))

val_dataloader = dict(
batch_size=1,
batch_size=4,
num_workers=4,
dataset=dict(
type=dataset_type,
Expand All @@ -85,7 +85,7 @@
persistent_workers=True)

test_dataloader = dict(
batch_size=1,
batch_size=4,
num_workers=4,
dataset=dict(
type=dataset_type,
Expand Down
5 changes: 3 additions & 2 deletions configs/_base_/gen_default_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
type='CheckpointHook',
interval=10000,
by_epoch=False,
less_keys=['FID-Full-50k/fid'],
greater_keys=['IS-50k/is'],
max_keep_ckpts=20,
less_keys=['FID-Full-50k/fid', 'swd/avg'],
greater_keys=['IS-50k/is', 'ms-ssim/avg'],
save_optimizer=True))

# config for environment
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/matting_default_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
fn_key='trimap_path',
img_keys=['pred_alpha', 'trimap', 'gt_merged', 'gt_alpha'],
bgr2rgb=True)
custom_hooks = [dict(type='BasicVisualizationHook', interval=1)]
custom_hooks = [dict(type='BasicVisualizationHook', interval=2000)]

log_level = 'INFO'
log_processor = dict(type='LogProcessor', by_epoch=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@
]
default_hooks = dict(
checkpoint=dict(
max_keep_ckpts=20,
save_best=['FID-Full-50k/fid', 'swd/avg', 'ms-ssim/avg'],
rule=['less', 'less', 'greater']))
max_keep_ckpts=20, save_best='FID-Full-50k/fid', rule='less'))

# METRICS
metrics = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
discriminator=dict(
optimizer=dict(type='Adam', lr=0.0001, betas=(0.5, 0.99))))

train_cfg = dict(max_iters=160000)

default_hooks = dict(
checkpoint=dict(
max_keep_ckpts=20,
save_best=['FID-Full-50k/fid', 'swd/avg', 'ms-ssim/avg'],
rule=['less', 'less', 'greater']))
max_keep_ckpts=20, save_best='FID-Full-50k/fid', rule='less'))

# VIS_HOOK
custom_hooks = [
Expand All @@ -37,8 +37,6 @@
vis_kwargs_list=dict(type='GAN', name='fake_img'))
]

train_cfg = dict(max_iters=160000)

# METRICS
metrics = [
dict(
Expand All @@ -58,14 +56,5 @@
image_shape=(3, 128, 128))
]

val_metrics = [
dict(
type='FrechetInceptionDistance',
prefix='FID-Full-50k',
fake_nums=50000,
inception_style='StyleGAN',
sample_model='orig'),
]

val_evaluator = dict(metrics=val_metrics)
val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)

0 comments on commit 71961a7

Please sign in to comment.