Commit 44c9eda: Support VFI and STVSR for VRT
Authored and committed by JingyunLiang
1 parent d1c0d9d

17 files changed, +8371 -145 lines

README.md (+7 -9)

@@ -1,12 +1,12 @@
-## Training and testing codes for USRNet, DnCNN, FFDNet, SRMD, DPSR, MSRResNet, ESRGAN, BSRGAN, SwinIR, VRT
+## Training and testing codes for USRNet, DnCNN, FFDNet, SRMD, DPSR, MSRResNet, ESRGAN, BSRGAN, SwinIR, VRT, RVRT
 
 [![download](https://img.shields.io/github/downloads/cszn/KAIR/total.svg)](https://github.com/cszn/KAIR/releases) ![visitors](https://visitor-badge.glitch.me/badge?page_id=cszn/KAIR)
 
 [Kai Zhang](https://cszn.github.io/)
 
 *[Computer Vision Lab](https://vision.ee.ethz.ch/the-institute.html), ETH Zurich, Switzerland*
 
 _______
-- **_News (2022-06-01)_**: We release [the training codes](https://github.com/cszn/KAIR/blob/master/docs/README_RVRT.md) of [RVRT ![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/RVRT?style=social)](https://github.com/JingyunLiang/RVRT) for video SR, deblurring and denoising.
+- **_News (2022-10-04)_**: We release [the training codes](https://github.com/cszn/KAIR/blob/master/docs/README_RVRT.md) of [RVRT, NeurIPS 2022 ![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/RVRT?style=social)](https://github.com/JingyunLiang/RVRT) for video SR, deblurring and denoising.
 
 - **_News (2022-05-05)_**: Try the [online demo](https://replicate.com/cszn/scunet) of [SCUNet ![GitHub Stars](https://img.shields.io/github/stars/cszn/SCUNet?style=social)](https://github.com/cszn/SCUNet) for blind real image denoising.
 

@@ -23,13 +23,11 @@ We did not use the paired noisy/clean data by DND and SIDD during training!*__
 
 
 - **_News (2022-02-15)_**: We release [the training codes](https://github.com/cszn/KAIR/blob/master/docs/README_VRT.md) of [VRT ![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/VRT?style=social)](https://github.com/JingyunLiang/VRT) for video SR, deblurring and denoising.
-<p align="center">
-<a href="https://github.com/JingyunLiang/VRT">
-<img width=30% src="https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vsr.gif"/>
-<img width=30% src="https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vdb.gif"/>
-<img width=30% src="https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vdn.gif"/>
-</a>
-</p>
+![Eg1](https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vsr.gif)
+![Eg2](https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vdb.gif)
+![Eg3](https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vdn.gif)
+![Eg4](https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_vfi.gif)
+![Eg5](https://raw.githubusercontent.com/JingyunLiang/VRT/main/assets/teaser_stvsr.gif)
 
 - **_News (2021-12-23)_**: Our techniques are adopted in [https://www.amemori.ai/](https://www.amemori.ai/).
 - **_News (2021-12-23)_**: Our new work for practical image denoising.

data/dataset_video_test.py (+137 -5)

@@ -2,6 +2,8 @@
 import torch
 from os import path as osp
 import torch.utils.data as data
+from torchvision import transforms
+from PIL import Image
 
 import utils.utils_video as utils_video
 

Note: the VFI test classes added below call `os.listdir` and `os.path.join`; they depend on a module-level `import os` that sits outside this hunk (only `from os import path as osp` is visible here).

@@ -245,12 +247,12 @@ def __init__(self, opt):
         super(VideoTestVimeo90KDataset, self).__init__()
         self.opt = opt
         self.cache_data = opt['cache_data']
-        temporal_scale = opt.get('temporal_scale', 1)
+        self.temporal_scale = opt.get('temporal_scale', 1)
         if self.cache_data:
             raise NotImplementedError('cache_data in Vimeo90K-Test dataset is not implemented.')
         self.gt_root, self.lq_root = opt['dataroot_gt'], opt['dataroot_lq']
         self.data_info = {'lq_path': [], 'gt_path': [], 'folder': [], 'idx': [], 'border': []}
-        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])][:: temporal_scale]
+        neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])][:: self.temporal_scale]
 
         with open(opt['meta_info_file'], 'r') as fin:
             subfolders = [line.split(' ')[0] for line in fin]
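
The `temporal_scale` option is the hook for the STVSR (space-time video super-resolution) setting named in the commit message: it thins the input frame indices by strided slicing. A small illustration of the slicing above (a sketch, not repository code):

```python
# Illustrative only: how temporal_scale thins the Vimeo-90K septuplet indices.
num_frame = 7
neighbor_list = [i + (9 - num_frame) // 2 for i in range(num_frame)]
print(neighbor_list)        # [1, 2, 3, 4, 5, 6, 7] -> im1.png ... im7.png
print(neighbor_list[::2])   # [1, 3, 5, 7]: with temporal_scale=2, only every
                            # other frame of the septuplet is used as LQ input
```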
@@ -259,7 +261,7 @@ def __init__(self, opt):
             self.data_info['gt_path'].append(gt_path)
             lq_paths = [osp.join(self.lq_root, subfolder, f'im{i}.png') for i in neighbor_list]
             self.data_info['lq_path'].append(lq_paths)
-            self.data_info['folder'].append('vimeo90k')
+            self.data_info['folder'].append(subfolder)
             self.data_info['idx'].append(f'{idx}/{len(subfolders)}')
             self.data_info['border'].append(0)
 
@@ -271,7 +273,6 @@ def __getitem__(self, index):
         gt_path = self.data_info['gt_path'][index]
         imgs_lq = utils_video.read_img_seq(lq_path)
         img_gt = utils_video.read_img_seq([gt_path])
-        img_gt.squeeze_(0)
 
         if self.pad_sequence:  # pad the sequence: 7 frames to 8 frames
             imgs_lq = torch.cat([imgs_lq, imgs_lq[-1:,...]], dim=0)
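
Dropping `img_gt.squeeze_(0)` keeps the ground truth as a length-1 sequence rather than a single frame, so 'H' now has the same (T, C, H, W) layout as the LQ clip. A tiny illustration (dummy tensor; 448x256 Vimeo frames assumed):

```python
# Illustrative only: the shape effect of removing img_gt.squeeze_(0).
import torch

img_gt = torch.zeros(1, 3, 256, 448)  # read_img_seq([gt_path]) -> (T=1, C, H, W)
# before this commit: img_gt.squeeze_(0) collapsed it to (3, 256, 448);
# keeping the temporal axis lets 'H' match imgs_lq, which is (T, C, H, W)
print(img_gt.shape)  # torch.Size([1, 3, 256, 448])
```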
@@ -285,9 +286,140 @@
             'folder': self.data_info['folder'][index],  # folder name
             'idx': self.data_info['idx'][index],  # e.g., 0/843
             'border': self.data_info['border'][index],  # 0 for non-border
-            'lq_path': lq_path[self.opt['num_frame'] // 2]  # center frame
+            'lq_path': lq_path,
+            'gt_path': [gt_path]
         }
 
     def __len__(self):
         return len(self.data_info['gt_path'])
 
+
+class VFI_DAVIS(data.Dataset):
+    """Video test dataset for DAVIS dataset in video frame interpolation.
+    Modified from https://github.com/tarun005/FLAVR/blob/main/dataset/Davis_test.py
+    """
+
+    def __init__(self, data_root, ext="png"):
+
+        super().__init__()
+
+        self.data_root = data_root
+        self.images_sets = []
+
+        for label_id in os.listdir(self.data_root):
+            ctg_imgs_ = sorted(os.listdir(os.path.join(self.data_root, label_id)))
+            ctg_imgs_ = [os.path.join(self.data_root, label_id, img_id) for img_id in ctg_imgs_]
+            # inputs are frames (i, i+2, i+4, i+6); the GT is the middle frame i+3
+            for start_idx in range(0, len(ctg_imgs_) - 6, 2):
+                add_files = ctg_imgs_[start_idx: start_idx + 7: 2]
+                add_files = add_files[:2] + [ctg_imgs_[start_idx + 3]] + add_files[2:]
+                self.images_sets.append(add_files)
+
+        self.transforms = transforms.Compose([
+            transforms.CenterCrop((480, 840)),
+            transforms.ToTensor()
+        ])
+
+    def __getitem__(self, idx):
+
+        imgpaths = self.images_sets[idx]
+        images = [Image.open(img) for img in imgpaths]
+        images = [self.transforms(img) for img in images]
+
+        return {
+            'L': torch.stack(images[:2] + images[3:], 0),  # four input frames
+            'H': images[2].unsqueeze(0),  # GT middle frame
+            'folder': str(idx),
+            'gt_path': ['vfi_result.png'],
+        }
+
+    def __len__(self):
+        return len(self.images_sets)
+
+
+class VFI_UCF101(data.Dataset):
+    """Video test dataset for UCF101 dataset in video frame interpolation.
+    Modified from https://github.com/tarun005/FLAVR/blob/main/dataset/ucf101_test.py
+    """
+
+    def __init__(self, data_root, ext="png"):
+        super().__init__()
+
+        self.data_root = data_root
+        self.file_list = sorted(os.listdir(self.data_root))
+
+        self.transforms = transforms.Compose([
+            transforms.CenterCrop((224, 224)),
+            transforms.ToTensor(),
+        ])
+
+    def __getitem__(self, idx):
+
+        imgpath = os.path.join(self.data_root, self.file_list[idx])
+        imgpaths = [os.path.join(imgpath, "frame0.png"),
+                    os.path.join(imgpath, "frame1.png"),
+                    os.path.join(imgpath, "frame2.png"),
+                    os.path.join(imgpath, "frame3.png"),
+                    os.path.join(imgpath, "framet.png")]
+
+        images = [Image.open(img) for img in imgpaths]
+        images = [self.transforms(img) for img in images]
+
+        return {
+            'L': torch.stack(images[:-1], 0),  # frame0..frame3 as inputs
+            'H': images[-1].unsqueeze(0),  # framet as GT
+            'folder': self.file_list[idx],
+            'gt_path': ['vfi_result.png'],
+        }
+
+    def __len__(self):
+        return len(self.file_list)
+
+
+class VFI_Vid4(data.Dataset):
+    """Video test dataset for Vid4 dataset in video frame interpolation.
+    Modified from https://github.com/tarun005/FLAVR/blob/main/dataset/Davis_test.py
+    """
+
+    def __init__(self, data_root, ext="png"):
+
+        super().__init__()
+
+        self.data_root = data_root
+        self.images_sets = []
+        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': []}
+        self.lq_path = []
+        self.folder = []
+
+        for label_id in os.listdir(self.data_root):
+            ctg_imgs_ = sorted(os.listdir(os.path.join(self.data_root, label_id)))
+            ctg_imgs_ = [os.path.join(self.data_root, label_id, img_id) for img_id in ctg_imgs_]
+            if len(ctg_imgs_) % 2 == 0:
+                ctg_imgs_.append(ctg_imgs_[-1])
+            # pad both ends so border frames can be interpolated; the list becomes
+            # [img0, None, img0, img1, ..., imgN, None, imgN], and the even-strided
+            # windows below never select the None placeholders
+            ctg_imgs_.insert(0, None)
+            ctg_imgs_.insert(0, ctg_imgs_[1])
+            ctg_imgs_.append(None)
+            ctg_imgs_.append(ctg_imgs_[-2])
+
+            for start_idx in range(0, len(ctg_imgs_) - 6, 2):
+                add_files = ctg_imgs_[start_idx: start_idx + 7: 2]
+                self.data_info['lq_path'].append([os.path.basename(path) for path in add_files])
+                self.data_info['gt_path'].append(os.path.basename(ctg_imgs_[start_idx + 3]))
+                self.data_info['folder'].append(label_id)
+                add_files = add_files[:2] + [ctg_imgs_[start_idx + 3]] + add_files[2:]
+                self.images_sets.append(add_files)
+
+        self.transforms = transforms.Compose([
+            transforms.ToTensor()
+        ])
+
+    def __getitem__(self, idx):
+        imgpaths = self.images_sets[idx]
+        images = [Image.open(img) for img in imgpaths]
+        images = [self.transforms(img) for img in images]
+
+        return {
+            'L': torch.stack(images[:2] + images[3:], 0),
+            'H': images[2].unsqueeze(0),
+            'folder': self.data_info['folder'][idx],
+            'lq_path': self.data_info['lq_path'][idx],
+            'gt_path': [self.data_info['gt_path'][idx]]
+        }
+
+    def __len__(self):
+        return len(self.images_sets)
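
A minimal usage sketch for the new VFI test sets, assuming a locally prepared UCF101 triplet test folder whose layout matches what `VFI_UCF101` above reads (`<data_root>/<clip>/frame0.png ... frame3.png` plus `framet.png`); the path is hypothetical:

```python
# Sketch only: iterate VFI_UCF101 and inspect batch shapes.
import torch
from torch.utils.data import DataLoader

from data.dataset_video_test import VFI_UCF101

test_set = VFI_UCF101(data_root='testsets/ucf101')  # hypothetical path
loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=2)

for batch in loader:
    lq = batch['L']  # (1, 4, 3, 224, 224): four context frames
    gt = batch['H']  # (1, 1, 3, 224, 224): the middle frame to interpolate
    # out = model(lq)  # a VFI model would predict the middle frame here
    break
```

`VFI_DAVIS` and `VFI_Vid4` return the same dict convention ('L' inputs, 'H' target), so the same loop applies to them.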

data/dataset_video_train.py (+70 -3)

@@ -3,6 +3,7 @@
 import torch
 from pathlib import Path
 import torch.utils.data as data
+from torchvision import transforms
 
 import utils.utils_video as utils_video
 
@@ -302,6 +303,7 @@ def __init__(self, opt):
         super(VideoRecurrentTrainVimeoDataset, self).__init__()
         self.opt = opt
         self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
+        self.temporal_scale = opt.get('temporal_scale', 1)
 
         with open(opt['meta_info_file'], 'r') as fin:
             self.keys = [line.split(' ')[0] for line in fin]
@@ -316,15 +318,14 @@ def __init__(self, opt):
             self.io_backend_opt['client_keys'] = ['lq', 'gt']
 
         # indices of input images
-        self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])]
+        self.neighbor_list = [i + (9 - opt['num_frame']) // 2 for i in range(opt['num_frame'])][::self.temporal_scale]
 
        # temporal augmentation configs
         self.random_reverse = opt['random_reverse']
         print(f'Random reverse is {self.random_reverse}.')
 
         self.mirror_sequence = opt.get('mirror_sequence', False)
         self.pad_sequence = opt.get('pad_sequence', False)
-        self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]
 
     def __getitem__(self, index):
         if self.file_client is None:
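
Note that the deleted `self.neighbor_list = [1, 2, 3, 4, 5, 6, 7]` was assigned after the computed list and unconditionally overwrote it, so `num_frame` and the new `temporal_scale` slicing had no effect until this removal.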
@@ -378,9 +379,75 @@ def __getitem__(self, index):
             img_gts = torch.cat([img_gts, img_gts[-1:,...]], dim=0)
 
         # img_lqs: (t, c, h, w)
-        # img_gt: (c, h, w)
+        # img_gt: (t, c, h, w)
         # key: str
         return {'L': img_lqs, 'H': img_gts, 'key': key}
 
     def __len__(self):
         return len(self.keys)
+
+
+class VideoRecurrentTrainVimeoVFIDataset(VideoRecurrentTrainVimeoDataset):
+    """Vimeo-90K septuplet training set for video frame interpolation (GT: im4)."""
+
+    def __init__(self, opt):
+        super(VideoRecurrentTrainVimeoVFIDataset, self).__init__(opt)
+        self.color_jitter = self.opt.get('color_jitter', False)
+
+        if self.color_jitter:
+            self.transforms_color_jitter = transforms.ColorJitter(0.05, 0.05, 0.05, 0.05)
+
+    def __getitem__(self, index):
+        if self.file_client is None:
+            self.file_client = utils_video.FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
+
+        # random reverse
+        if self.random_reverse and random.random() < 0.5:
+            self.neighbor_list.reverse()
+
+        scale = self.opt['scale']
+        gt_size = self.opt['gt_size']
+        key = self.keys[index]
+        clip, seq = key.split('/')  # key example: 00001/0001
+
+        # get the neighboring LQ and GT frames
+        img_lqs = []
+        img_gts = []
+        for neighbor in self.neighbor_list:
+            if self.is_lmdb:
+                img_lq_path = f'{clip}/{seq}/im{neighbor}'
+            else:
+                img_lq_path = self.lq_root / clip / seq / f'im{neighbor}.png'
+            # LQ
+            img_bytes = self.file_client.get(img_lq_path, 'lq')
+            img_lq = utils_video.imfrombytes(img_bytes, float32=True)
+            img_lqs.append(img_lq)
+
+        # GT: im4, the middle frame of the septuplet, is the interpolation target
+        if self.is_lmdb:
+            img_gt_path = f'{clip}/{seq}/im4'
+        else:
+            img_gt_path = self.gt_root / clip / seq / 'im4.png'
+
+        img_bytes = self.file_client.get(img_gt_path, 'gt')
+        img_gt = utils_video.imfrombytes(img_bytes, float32=True)
+        img_gts.append(img_gt)
+
+        # randomly crop
+        img_gts, img_lqs = utils_video.paired_random_crop(img_gts, img_lqs, gt_size, scale, img_gt_path)
+
+        # augmentation - flip, rotate (GT appended so it gets the same transform;
+        # paired_random_crop unwraps a one-element list to a bare array)
+        img_lqs.extend([img_gts])
+        img_results = utils_video.augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
+
+        img_results = utils_video.img2tensor(img_results)
+        img_results = torch.stack(img_results, dim=0)
+
+        if self.color_jitter:  # same color_jitter for img_lqs and img_gts
+            img_results = self.transforms_color_jitter(img_results)
+
+        img_lqs = img_results[:-1, ...]
+        img_gts = img_results[-1:, ...]
+
+        # img_lqs: (t, c, h, w)
+        # img_gt: (t, c, h, w)
+        # key: str
+        return {'L': img_lqs, 'H': img_gts, 'key': key}
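
For completeness, a hedged sketch of an options dict for the new training class; every key below is read by the class or its parent in the code above, but the paths and values are assumptions, not the repository's official VRT training settings:

```python
# Sketch only: instantiate VideoRecurrentTrainVimeoVFIDataset.
from data.dataset_video_train import VideoRecurrentTrainVimeoVFIDataset

opt = {
    'dataroot_gt': 'trainsets/vimeo90k/sequences',  # hypothetical path
    'dataroot_lq': 'trainsets/vimeo90k/sequences',  # VFI reads clean frames on both sides
    'meta_info_file': 'data/meta_info/meta_info_Vimeo90K_train_GT.txt',  # hypothetical path
    'io_backend': {'type': 'disk'},
    'num_frame': 7,
    'temporal_scale': 2,   # inputs im1, im3, im5, im7; GT is im4
    'scale': 1,            # frame interpolation: no spatial downscaling
    'gt_size': 192,
    'random_reverse': True,
    'use_hflip': True,
    'use_rot': True,
    'color_jitter': True,  # joint ColorJitter(0.05, 0.05, 0.05, 0.05) on LQ and GT
}

train_set = VideoRecurrentTrainVimeoVFIDataset(opt)
sample = train_set[0]
print(sample['L'].shape, sample['H'].shape)  # (4, 3, 192, 192), (1, 3, 192, 192)
```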
