2
2
import os
from os import path as osp

import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms

import utils.utils_video as utils_video
7
9
@@ -245,12 +247,12 @@ def __init__(self, opt):
245
247
super (VideoTestVimeo90KDataset , self ).__init__ ()
246
248
self .opt = opt
247
249
self .cache_data = opt ['cache_data' ]
248
- temporal_scale = opt .get ('temporal_scale' , 1 )
250
+ self . temporal_scale = opt .get ('temporal_scale' , 1 )
249
251
if self .cache_data :
250
252
raise NotImplementedError ('cache_data in Vimeo90K-Test dataset is not implemented.' )
251
253
self .gt_root , self .lq_root = opt ['dataroot_gt' ], opt ['dataroot_lq' ]
252
254
self .data_info = {'lq_path' : [], 'gt_path' : [], 'folder' : [], 'idx' : [], 'border' : []}
253
- neighbor_list = [i + (9 - opt ['num_frame' ]) // 2 for i in range (opt ['num_frame' ])][:: temporal_scale ]
255
+ neighbor_list = [i + (9 - opt ['num_frame' ]) // 2 for i in range (opt ['num_frame' ])][:: self . temporal_scale ]
254
256
255
257
with open (opt ['meta_info_file' ], 'r' ) as fin :
256
258
subfolders = [line .split (' ' )[0 ] for line in fin ]
@@ -259,7 +261,7 @@ def __init__(self, opt):
259
261
self .data_info ['gt_path' ].append (gt_path )
260
262
lq_paths = [osp .join (self .lq_root , subfolder , f'im{ i } .png' ) for i in neighbor_list ]
261
263
self .data_info ['lq_path' ].append (lq_paths )
262
- self .data_info ['folder' ].append ('vimeo90k' )
264
+ self .data_info ['folder' ].append (subfolder )
263
265
self .data_info ['idx' ].append (f'{ idx } /{ len (subfolders )} ' )
264
266
self .data_info ['border' ].append (0 )
265
267
@@ -271,7 +273,6 @@ def __getitem__(self, index):
271
273
gt_path = self .data_info ['gt_path' ][index ]
272
274
imgs_lq = utils_video .read_img_seq (lq_path )
273
275
img_gt = utils_video .read_img_seq ([gt_path ])
274
- img_gt .squeeze_ (0 )
275
276
276
277
if self .pad_sequence : # pad the sequence: 7 frames to 8 frames
277
278
imgs_lq = torch .cat ([imgs_lq , imgs_lq [- 1 :,...]], dim = 0 )
@@ -285,9 +286,140 @@ def __getitem__(self, index):
285
286
'folder' : self .data_info ['folder' ][index ], # folder name
286
287
'idx' : self .data_info ['idx' ][index ], # e.g., 0/843
287
288
'border' : self .data_info ['border' ][index ], # 0 for non-border
288
- 'lq_path' : lq_path [self .opt ['num_frame' ] // 2 ] # center frame
289
+ 'lq_path' : lq_path ,
290
+ 'gt_path' : [gt_path ]
289
291
}
290
292
291
293
def __len__ (self ):
292
294
return len (self .data_info ['gt_path' ])
293
295
296
+
297
class VFI_DAVIS(data.Dataset):
    """Video test dataset for DAVIS in video frame interpolation.

    Modified from https://github.com/tarun005/FLAVR/blob/main/dataset/Davis_test.py

    Each sample is a 7-frame window (stride 2 over a sequence folder): the
    four even-offset frames are the inputs ('L') and the middle frame of the
    window is the interpolation target ('H').

    Args:
        data_root (str): Root directory with one sub-folder per sequence.
        ext (str): Image extension. Currently unused; kept for interface
            compatibility with the other VFI test datasets.
    """

    def __init__(self, data_root, ext="png"):
        super().__init__()

        self.data_root = data_root
        self.images_sets = []

        for label_id in os.listdir(self.data_root):
            ctg_imgs_ = sorted(os.listdir(os.path.join(self.data_root, label_id)))
            ctg_imgs_ = [os.path.join(self.data_root, label_id, img_id) for img_id in ctg_imgs_]
            for start_idx in range(0, len(ctg_imgs_) - 6, 2):
                # Four input frames taken with stride 2 from the window ...
                add_files = ctg_imgs_[start_idx: start_idx + 7: 2]
                # ... with the ground-truth middle frame spliced in at position 2.
                add_files = add_files[:2] + [ctg_imgs_[start_idx + 3]] + add_files[2:]
                self.images_sets.append(add_files)

        self.transforms = transforms.Compose([
            transforms.CenterCrop((480, 840)),
            transforms.ToTensor()
        ])

    def __getitem__(self, idx):
        imgpaths = self.images_sets[idx]
        # Open each image in a context manager so PIL file handles are
        # released immediately instead of leaking until garbage collection.
        images = []
        for imgpath in imgpaths:
            with Image.open(imgpath) as img:
                images.append(self.transforms(img))

        return {
            'L': torch.stack(images[:2] + images[3:], 0),  # 4 input frames
            'H': images[2].unsqueeze(0),                   # GT middle frame
            'folder': str(idx),
            'gt_path': ['vfi_result.png'],
        }

    def __len__(self):
        return len(self.images_sets)
337
+
338
+
339
class VFI_UCF101(data.Dataset):
    """Video test dataset for UCF101 in video frame interpolation.

    Modified from https://github.com/tarun005/FLAVR/blob/main/dataset/ucf101_test.py

    Each sub-folder of ``data_root`` holds four input frames
    (frame0.png .. frame3.png) and the ground-truth middle frame (framet.png).

    Args:
        data_root (str): Root directory with one sub-folder per sample.
        ext (str): Image extension. Currently unused; kept for interface
            compatibility with the other VFI test datasets.
    """

    def __init__(self, data_root, ext="png"):
        super().__init__()

        self.data_root = data_root
        self.file_list = sorted(os.listdir(self.data_root))

        self.transforms = transforms.Compose([
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
        ])

    def __getitem__(self, idx):
        imgpath = os.path.join(self.data_root, self.file_list[idx])
        # Four input frames followed by the ground-truth middle frame.
        imgpaths = [os.path.join(imgpath, f'frame{i}.png') for i in range(4)]
        imgpaths.append(os.path.join(imgpath, 'framet.png'))

        # Open each image in a context manager so PIL file handles are
        # released immediately instead of leaking until garbage collection.
        images = []
        for p in imgpaths:
            with Image.open(p) as img:
                images.append(self.transforms(img))

        return {
            'L': torch.stack(images[:-1], 0),  # 4 input frames
            'H': images[-1].unsqueeze(0),      # GT middle frame
            'folder': self.file_list[idx],
            'gt_path': ['vfi_result.png'],
        }

    def __len__(self):
        return len(self.file_list)
372
+
373
+
374
class VFI_Vid4(data.Dataset):
    """Video test dataset for Vid4 dataset in video frame interpolation.
    Modified from https://github.com/tarun005/FLAVR/blob/main/dataset/Davis_test.py
    """

    def __init__(self, data_root, ext="png"):
        # data_root: directory with one sub-folder per sequence.
        # ext: image extension; currently unused.
        super().__init__()

        self.data_root = data_root
        # One entry per sample: 4 input frame paths with the GT path spliced
        # in at position 2 (5 paths total).
        self.images_sets = []
        self.data_info = {'lq_path': [], 'gt_path': [], 'folder': []}
        self.lq_path = []  # NOTE(review): never populated in this class — possibly vestigial
        self.folder = []   # NOTE(review): never populated in this class — possibly vestigial

        for label_id in os.listdir(self.data_root):
            ctg_imgs_ = sorted(os.listdir(os.path.join(self.data_root, label_id)))
            ctg_imgs_ = [os.path.join(self.data_root, label_id, img_id) for img_id in ctg_imgs_]
            # Pad the sequence so the 7-frame sliding windows cover the
            # borders. Order matters: after the two inserts the list starts
            # [first, None, first, ...] and after the two appends it ends
            # [..., last, None, last]. The None sentinels land on odd indices
            # (1 and len-2), which neither the even-offset input slices nor
            # the GT index start_idx+3 (ranging over 3 .. len-4) ever select.
            if len(ctg_imgs_) % 2 == 0:
                # Make the frame count odd first by duplicating the last frame.
                ctg_imgs_.append(ctg_imgs_[-1])
            ctg_imgs_.insert(0, None)
            ctg_imgs_.insert(0, ctg_imgs_[1])
            ctg_imgs_.append(None)
            ctg_imgs_.append(ctg_imgs_[-2])

            for start_idx in range(0, len(ctg_imgs_) - 6, 2):
                # Four input frames taken with stride 2 from the window.
                add_files = ctg_imgs_[start_idx: start_idx + 7: 2]
                self.data_info['lq_path'].append([os.path.basename(path) for path in add_files])
                self.data_info['gt_path'].append(os.path.basename(ctg_imgs_[start_idx + 3]))
                self.data_info['folder'].append(label_id)
                # Splice the ground-truth middle frame in at position 2.
                add_files = add_files[:2] + [ctg_imgs_[start_idx + 3]] + add_files[2:]
                self.images_sets.append(add_files)

        self.transforms = transforms.Compose([
            transforms.ToTensor()
        ])

    def __getitem__(self, idx):
        # Returns 'L' (4 stacked input frames), 'H' (the GT middle frame with
        # a leading batch dim), plus bookkeeping entries for the test loop.
        imgpaths = self.images_sets[idx]
        images = [Image.open(img) for img in imgpaths]
        images = [self.transforms(img) for img in images]

        return {
            'L': torch.stack(images[:2] + images[3:], 0),
            'H': images[2].unsqueeze(0),
            'folder': self.data_info['folder'][idx],
            'lq_path': self.data_info['lq_path'][idx],
            'gt_path': [self.data_info['gt_path'][idx]]
        }

    def __len__(self):
        return len(self.images_sets)
0 commit comments