Skip to content

Commit 77940b8

Browse files
authored
Moved pfm file reading into dataset utils (#6270)
* Moved pfm file reading into dataset utils * Made _read_pfm private. Fixed doc format issues.
1 parent 9effc4c commit 77940b8

File tree

2 files changed

+38
-30
lines changed

2 files changed

+38
-30
lines changed

torchvision/datasets/_optical_flow.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import itertools
22
import os
3-
import re
43
from abc import ABC, abstractmethod
54
from glob import glob
65
from pathlib import Path
@@ -10,7 +9,7 @@
109
from PIL import Image
1110

1211
from ..io.image import _read_png_16
13-
from .utils import verify_str_arg
12+
from .utils import verify_str_arg, _read_pfm
1413
from .vision import VisionDataset
1514

1615

@@ -472,31 +471,3 @@ def _read_16bits_png_with_flow_and_valid_mask(file_name):
472471

473472
# For consistency with other datasets, we convert to numpy
474473
return flow.numpy(), valid_flow_mask.numpy()
475-
476-
477-
def _read_pfm(file_name):
    """Read optical flow stored in a 3-channel .pfm file.

    Args:
        file_name: Path to the .pfm file.

    Returns:
        A ``float32`` numpy array of shape ``(2, H, W)``: the first two of the
        three stored channels, with the rows flipped along the height axis.

    Raises:
        ValueError: If the first line is not ``PF``, if the dimension line is
            malformed, or if the scale line cannot be parsed as a float.
    """
    with open(file_name, "rb") as f:
        header = f.readline().rstrip()
        if header != b"PF":
            raise ValueError("Invalid PFM file")

        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
        if not dim_match:
            # ValueError (a subclass of Exception) keeps both header failures
            # on one exception type; callers catching Exception still work.
            raise ValueError("Malformed PFM header.")
        w, h = (int(dim) for dim in dim_match.groups())

        # Only the sign of the scale line is used: negative selects
        # little-endian data, anything else big-endian.
        scale = float(f.readline().rstrip())
        endian = "<" if scale < 0 else ">"

        data = np.fromfile(f, dtype=endian + "f")

    data = data.reshape(h, w, 3).transpose(2, 0, 1)
    data = np.flip(data, axis=1)  # flip on h dimension
    data = data[:2, :, :]
    return data.astype(np.float32)

torchvision/datasets/utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
1919
from urllib.parse import urlparse
2020

21+
import numpy as np
2122
import requests
2223
import torch
2324
from torch.utils.model_zoo import tqdm
@@ -483,3 +484,39 @@ def verify_str_arg(
483484
raise ValueError(msg)
484485

485486
return value
487+
488+
489+
def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
    """Read a file in .pfm format, holding either 1 or 3 channels of data.

    A ``PF`` header line is read as 3 channels and a ``Pf`` header line as 1.

    Args:
        file_name (str): Path to the file.
        slice_channels (int): Number of leading channels to keep from the decoded data.
            Useful for reading different data formats stored in .pfm files:
            Optical Flows, Stereo Disparity Maps, etc.

    Returns:
        np.ndarray: ``float32`` array laid out as ``(channels, height, width)``, with
        at most ``slice_channels`` channels and the rows flipped along the height axis.

    Raises:
        ValueError: If the first line is neither ``PF`` nor ``Pf``, if the dimension
            line is malformed, or if the scale line cannot be parsed as a float.
    """
    with open(file_name, "rb") as f:
        header = f.readline().rstrip()
        if header not in (b"PF", b"Pf"):
            raise ValueError("Invalid PFM file")

        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline())
        if not dim_match:
            # ValueError (a subclass of Exception) keeps both header failures
            # on one exception type; callers catching Exception still work.
            raise ValueError("Malformed PFM header.")
        w, h = (int(dim) for dim in dim_match.groups())

        # Only the sign of the scale line is used: negative selects
        # little-endian data, anything else big-endian.
        scale = float(f.readline().rstrip())
        endian = "<" if scale < 0 else ">"

        data = np.fromfile(f, dtype=endian + "f")

    pfm_channels = 3 if header == b"PF" else 1

    data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1)
    data = np.flip(data, axis=1)  # flip on h dimension
    data = data[:slice_channels, :, :]
    return data.astype(np.float32)

0 commit comments

Comments
 (0)