forked from open-mmlab/mmagic
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tof_x4_vimeo90k_official.py
76 lines (69 loc) · 2.07 KB
/
tof_x4_vimeo90k_official.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# TOFlow x4 (Vimeo-90K) config. Only evaluation of the official, released
# TOFlow model is supported: train_cfg is None and no training data is set up.
exp_name = 'tof_x4_vimeo90k_official'

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
# The TOFlow generator is wrapped in the 'EDVR' model type, which is shared
# with EDVR. NOTE(review): `adapt_official_weights=True` presumably remaps the
# official checkpoint's parameters for TOFlow -- confirm in the TOFlow class.
model = {
    'type': 'EDVR',
    'generator': {'type': 'TOFlow', 'adapt_official_weights': True},
    'pixel_loss': {
        'type': 'CharbonnierLoss',
        'loss_weight': 1.0,
        'reduction': 'sum',
    },
}

# No training settings. At test time PSNR and SSIM are reported, with
# crop_border=0 (no border pixels excluded from the metrics).
train_cfg = None
test_cfg = {'metrics': ['PSNR', 'SSIM'], 'crop_border': 0}
# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
val_dataset_type = 'SRVid4Dataset'

# Vid4 evaluation pipeline, applied in order:
#   1. build the per-sample frame indices, padding with 'reflection_circle'
#      (boundary handling per the transform's name -- confirm);
#   2. load the LQ and GT frame lists from disk with flag='unchanged';
#   3. rescale both to [0, 1];
#   4. Normalize with mean 0 / std 1, which leaves values as they are, so the
#      step appears to exist only for its `to_rgb=True` channel reordering;
#   5. Collect the 'lq'/'gt' entries plus path/key meta information;
#   6. convert the frame lists to tensors.
test_pipeline = [
    {'type': 'GenerateFrameIndiceswithPadding', 'padding': 'reflection_circle'},
    {
        'type': 'LoadImageFromFileList',
        'io_backend': 'disk',
        'key': 'lq',
        'flag': 'unchanged',
    },
    {
        'type': 'LoadImageFromFileList',
        'io_backend': 'disk',
        'key': 'gt',
        'flag': 'unchanged',
    },
    {'type': 'RescaleToZeroOne', 'keys': ['lq', 'gt']},
    {
        'type': 'Normalize',
        'keys': ['lq', 'gt'],
        'mean': [0, 0, 0],
        'std': [1, 1, 1],
        'to_rgb': True,
    },
    {
        'type': 'Collect',
        'keys': ['lq', 'gt'],
        'meta_keys': ['lq_path', 'gt_path', 'key'],
    },
    {'type': 'FramesToTensor', 'keys': ['lq', 'gt']},
]
# LQ-only pipeline (no 'gt' key). NOTE(review): presumably consumed by the
# demo/inference tooling -- confirm. Frames are loaded with
# channel_order='rgb', and there is no Normalize step. Collect runs last.
demo_pipeline = [
    {'type': 'GenerateSegmentIndices', 'interval_list': [1]},
    {
        'type': 'LoadImageFromFileList',
        'io_backend': 'disk',
        'key': 'lq',
        'channel_order': 'rgb',
    },
    {'type': 'RescaleToZeroOne', 'keys': ['lq']},
    {'type': 'FramesToTensor', 'keys': ['lq']},
    {'type': 'Collect', 'keys': ['lq'], 'meta_keys': ['lq_path', 'key']},
]
# Only a 'test' split is defined. LQ frames come from
# data/Vid4/BIx4up_direct and GT frames from data/Vid4/GT, listed by the
# annotation file. Each sample uses 7 input frames at scale 4.
data = {
    'workers_per_gpu': 8,
    'test': {
        'type': val_dataset_type,
        'lq_folder': 'data/Vid4/BIx4up_direct',
        'gt_folder': 'data/Vid4/GT',
        'ann_file': 'data/Vid4/meta_info_Vid4_GT.txt',
        'num_input_frames': 7,
        'pipeline': test_pipeline,
        'scale': 4,
        'test_mode': True,
    },
}

# NOTE(review): `interval` is presumably only read by a training-time
# evaluation hook, which this test-only config never sets up.
evaluation = {'interval': 5000, 'save_image': False, 'gpu_collect': False}
visual_config = None

# ---------------------------------------------------------------------------
# Runtime
# ---------------------------------------------------------------------------
# NOTE(review): presumably only used for distributed launches.
dist_params = {'backend': 'nccl'}
log_level = 'INFO'
# Outputs go under a directory named after the experiment.
work_dir = f'./work_dirs/{exp_name}'
# No checkpoint to load or resume from is configured here.
load_from = None
resume_from = None
# NOTE(review): no `data['train']` is defined in this config, so this
# 'train' workflow presumably is never exercised -- confirm before removing.
workflow = [('train', 1)]