Skip to content

Commit d1c0d9d

Browse files
author
JingyunLiang
committed
add RVRT training codes
1 parent 06bd194 commit d1c0d9d

19 files changed

+4045
-99
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
*[Computer Vision Lab](https://vision.ee.ethz.ch/the-institute.html), ETH Zurich, Switzerland*
77

88
_______
9+
- **_News (2022-06-01)_**: We release [the training codes](https://github.com/cszn/KAIR/blob/master/docs/README_RVRT.md) of [RVRT ![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/RVRT?style=social)](https://github.com/JingyunLiang/RVRT) for video SR, deblurring and denoising.
10+
911
- **_News (2022-05-05)_**: Try the [online demo](https://replicate.com/cszn/scunet) of [SCUNet ![GitHub Stars](https://img.shields.io/github/stars/cszn/SCUNet?style=social)](https://github.com/cszn/SCUNet) for blind real image denoising.
1012

1113
- **_News (2022-03-23)_**: We release [the testing codes](https://github.com/cszn/SCUNet) of [SCUNet ![GitHub Stars](https://img.shields.io/github/stars/cszn/SCUNet?style=social)](https://github.com/cszn/SCUNet) for blind real image denoising.

data/dataset_video_test.py

+7-96
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def __len__(self):
129129

130130

131131
class SingleVideoRecurrentTestDataset(data.Dataset):
132-
"""Single ideo test dataset for recurrent architectures, which takes LR video
132+
"""Single video test dataset for recurrent architectures, which takes LR video
133133
frames as input and output corresponding HR video frames (only input LQ path).
134134
135135
More generally, it supports testing dataset with following structures:
@@ -245,11 +245,12 @@ def __init__(self, opt):
245245
super(VideoTestVimeo90KDataset, self).__init__()
246246
self.opt = opt
247247
self.cache_data = opt['cache_data']
248+
temporal_scale = opt.get('temporal_scale', 1)
248249
if self.cache_data:
249250
raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
250251
self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
251252
self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
252-
neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
253+
neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])][:: temporal_scale]
253254

254255
with open(opt['meta_info_file'], 'r') as fin:
255256
subfolders = [line.split(' ')[0] for line in fin]
@@ -263,6 +264,7 @@ def __init__(self, opt):
263264
self.data_info['border'].append(0)
264265

265266
self.pad_sequence = opt.get('pad_sequence', False)
267+
self.mirror_sequence = opt.get('mirror_sequence', False)
266268

267269
def __getitem__(self, index):
268270
lq_path = self.data_info['lq_path'][index]
@@ -274,6 +276,9 @@ def __getitem__(self, index):
274276
if self.pad_sequence: # pad the sequence: 7 frames to 8 frames
275277
imgs_lq = torch.cat([imgs_lq, imgs_lq[-1:,...]], dim=0)
276278

279+
if self.mirror_sequence: # mirror the sequence: 7 frames to 14 frames
280+
imgs_lq = torch.cat([imgs_lq, imgs_lq.flip(0)], dim=0)
281+
277282
return {
278283
'L': imgs_lq, # (t, c, h, w)
279284
'H': img_gt, # (c, h, w)
@@ -286,97 +291,3 @@ def __getitem__(self, index):
286291
def __len__(self):
287292
return len(self.data_info['gt_path'])
288293

289-
290-
class SingleVideoRecurrentTestDataset(data.Dataset):
291-
"""Single Video test dataset (only input LQ path).
292-
293-
Supported datasets: Vid4, REDS4, REDSofficial.
294-
More generally, it supports testing dataset with following structures:
295-
296-
dataroot
297-
├── subfolder1
298-
├── frame000
299-
├── frame001
300-
├── ...
301-
├── subfolder1
302-
├── frame000
303-
├── frame001
304-
├── ...
305-
├── ...
306-
307-
For testing datasets, there is no need to prepare LMDB files.
308-
309-
Args:
310-
opt (dict): Config for train dataset. It contains the following keys:
311-
dataroot_gt (str): Data root path for gt.
312-
dataroot_lq (str): Data root path for lq.
313-
io_backend (dict): IO backend type and other kwarg.
314-
cache_data (bool): Whether to cache testing datasets.
315-
name (str): Dataset name.
316-
meta_info_file (str): The path to the file storing the list of test
317-
folders. If not provided, all the folders in the dataroot will
318-
be used.
319-
num_frame (int): Window size for input frames.
320-
padding (str): Padding mode.
321-
"""
322-
323-
def __init__(self, opt):
324-
super(SingleVideoRecurrentTestDataset, self).__init__()
325-
self.opt = opt
326-
self.cache_data = opt['cache_data']
327-
self.lq_root = opt['dataroot_lq']
328-
self.data_info = {'lq_path': [], 'folder': [], 'idx': [], 'border': []}
329-
# file client (io backend)
330-
self.file_client = None
331-
332-
self.imgs_lq = {}
333-
if 'meta_info_file' in opt:
334-
with open(opt['meta_info_file'], 'r') as fin:
335-
subfolders = [line.split(' ')[0] for line in fin]
336-
subfolders_lq = [osp.join(self.lq_root, key) for key in subfolders]
337-
else:
338-
subfolders_lq = sorted(glob.glob(osp.join(self.lq_root, '*')))
339-
340-
for subfolder_lq in subfolders_lq:
341-
# get frame list for lq and gt
342-
subfolder_name = osp.basename(subfolder_lq)
343-
img_paths_lq = sorted(list(utils_video.scandir(subfolder_lq, full_path=True)))
344-
345-
max_idx = len(img_paths_lq)
346-
347-
self.data_info['lq_path'].extend(img_paths_lq)
348-
self.data_info['folder'].extend([subfolder_name] * max_idx)
349-
for i in range(max_idx):
350-
self.data_info['idx'].append(f'{i}/{max_idx}')
351-
border_l = [0] * max_idx
352-
for i in range(self.opt['num_frame'] // 2):
353-
border_l[i] = 1
354-
border_l[max_idx - i - 1] = 1
355-
self.data_info['border'].extend(border_l)
356-
357-
# cache data or save the frame list
358-
if self.cache_data:
359-
logger.info(f'Cache {subfolder_name} for VideoTestDataset...')
360-
self.imgs_lq[subfolder_name] = utils_video.read_img_seq(img_paths_lq)
361-
else:
362-
self.imgs_lq[subfolder_name] = img_paths_lq
363-
364-
# Find unique folder strings
365-
self.folders = sorted(list(set(self.data_info['folder'])))
366-
367-
def __getitem__(self, index):
368-
folder = self.folders[index]
369-
370-
if self.cache_data:
371-
imgs_lq = self.imgs_lq[folder]
372-
else:
373-
imgs_lq = utils_video.read_img_seq(self.imgs_lq[folder])
374-
375-
return {
376-
'L': imgs_lq,
377-
'folder': folder,
378-
'lq_path': self.imgs_lq[folder],
379-
}
380-
381-
def __len__(self):
382-
return len(self.folders)

data/dataset_video_train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def __init__(self, opt):
322322
self.random_reverse = opt['random_reverse']
323323
print(f'Random reverse is {self.random_reverse}.')
324324

325-
self.flip_sequence = opt.get('flip_sequence', False)
325+
self.mirror_sequence = opt.get('mirror_sequence', False)
326326
self.pad_sequence = opt.get('pad_sequence', False)
327327
self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]
328328

@@ -370,7 +370,7 @@ def __getitem__(self, index):
370370
img_lqs = torch.stack(img_results[:7], dim=0)
371371
img_gts = torch.stack(img_results[7:], dim=0)
372372

373-
if self.flip_sequence: # flip the sequence: 7 frames to 14 frames
373+
if self.mirror_sequence: # mirror the sequence: 7 frames to 14 frames
374374
img_lqs = torch.cat([img_lqs, img_lqs.flip(0)], dim=0)
375375
img_gts = torch.cat([img_gts, img_gts.flip(0)], dim=0)
376376
elif self.pad_sequence: # pad the sequence: 7 frames to 8 frames

0 commit comments

Comments (0)