Skip to content

Commit 0fcfaa1

Browse files
Add stereo preset transforms (#6549)
* Added transforms for Stereo Matching * changed implicit Y scaling to 0. * Adressed some comments * addressed type hint * Added interpolation random interpolation strategy * Aligned crop get params * fixed bug in RandomErase * Adressed scaling and typos * Adressed occlusion typo * Changed parameter order in F.erase * fixed random erase * Added inference preset transform for stereo matching * added contiguous reshape to output tensors * Adressed comments * Modified the transform preset to use Tuple[int, int] * adressed NITs * added grayscale transform, align resize -> mask * changed max disparity default behaviour * added fixed resize, changed masking in sparse flow masking * update to align with argparse * changed default mask in asymetric pairs * moved grayscale order * changed grayscale api to accept to tensor variant * mypy fix * changed resize specs * adressed nits * added type hints * mypy fix * mypy fix * mypy fix Co-authored-by: Joao Gomes <[email protected]>
1 parent 2c1022e commit 0fcfaa1

File tree

3 files changed

+864
-0
lines changed

3 files changed

+864
-0
lines changed

references/depth/stereo/presets.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from typing import Optional, Tuple, Union
2+
3+
import torch
4+
import transforms as T
5+
6+
7+
class StereoMatchingEvalPreset(torch.nn.Module):
8+
def __init__(
9+
self,
10+
mean: float = 0.5,
11+
std: float = 0.5,
12+
resize_size: Optional[Tuple[int, ...]] = None,
13+
max_disparity: Optional[float] = None,
14+
interpolation_type: str = "bilinear",
15+
use_grayscale: bool = False,
16+
) -> None:
17+
super().__init__()
18+
19+
transforms = [
20+
T.ToTensor(),
21+
T.ConvertImageDtype(torch.float32),
22+
]
23+
24+
if use_grayscale:
25+
transforms.append(T.ConvertToGrayscale())
26+
27+
if resize_size is not None:
28+
transforms.append(T.Resize(resize_size, interpolation_type=interpolation_type))
29+
30+
transforms.extend(
31+
[
32+
T.Normalize(mean=mean, std=std),
33+
T.MakeValidDisparityMask(max_disparity=max_disparity),
34+
T.ValidateModelInput(),
35+
]
36+
)
37+
38+
self.transforms = T.Compose(transforms)
39+
40+
def forward(self, images, disparities, masks):
41+
return self.transforms(images, disparities, masks)
42+
43+
44+
class StereoMatchingTrainPreset(torch.nn.Module):
45+
def __init__(
46+
self,
47+
*,
48+
resize_size: Optional[Tuple[int, ...]],
49+
resize_interpolation_type: str = "bilinear",
50+
# RandomResizeAndCrop params
51+
crop_size: Tuple[int, int],
52+
rescale_prob: float = 1.0,
53+
scaling_type: str = "exponential",
54+
scale_range: Tuple[float, float] = (-0.2, 0.5),
55+
scale_interpolation_type: str = "bilinear",
56+
# convert to grayscale
57+
use_grayscale: bool = False,
58+
# normalization params
59+
mean: float = 0.5,
60+
std: float = 0.5,
61+
# processing device
62+
gpu_transforms: bool = False,
63+
# masking
64+
max_disparity: Optional[int] = 256,
65+
# SpatialShift params
66+
spatial_shift_prob: float = 0.5,
67+
spatial_shift_max_angle: float = 0.5,
68+
spatial_shift_max_displacement: float = 0.5,
69+
spatial_shift_interpolation_type: str = "bilinear",
70+
# AssymetricColorJitter
71+
gamma_range: Tuple[float, float] = (0.8, 1.2),
72+
brightness: Union[int, Tuple[int, int]] = (0.8, 1.2),
73+
contrast: Union[int, Tuple[int, int]] = (0.8, 1.2),
74+
saturation: Union[int, Tuple[int, int]] = 0.0,
75+
hue: Union[int, Tuple[int, int]] = 0.0,
76+
asymmetric_jitter_prob: float = 1.0,
77+
# RandomHorizontalFlip
78+
horizontal_flip_prob: float = 0.5,
79+
# RandomOcclusion
80+
occlusion_prob: float = 0.0,
81+
occlusion_px_range: Tuple[int, int] = (50, 100),
82+
# RandomErase
83+
erase_prob: float = 0.0,
84+
erase_px_range: Tuple[int, int] = (50, 100),
85+
erase_num_repeats: int = 1,
86+
) -> None:
87+
88+
if scaling_type not in ["linear", "exponential"]:
89+
raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")
90+
91+
super().__init__()
92+
transforms = [T.ToTensor()]
93+
94+
# when fixing size across multiple datasets, we ensure
95+
# that the same size is used for all datasets when cropping
96+
if resize_size is not None:
97+
transforms.append(T.Resize(resize_size, interpolation_type=resize_interpolation_type))
98+
99+
if gpu_transforms:
100+
transforms.append(T.ToGPU())
101+
102+
# color handling
103+
color_transforms = [
104+
T.AsymmetricColorJitter(
105+
brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
106+
),
107+
T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
108+
]
109+
110+
if use_grayscale:
111+
color_transforms.append(T.ConvertToGrayscale())
112+
113+
transforms.extend(color_transforms)
114+
115+
transforms.extend(
116+
[
117+
T.RandomSpatialShift(
118+
p=spatial_shift_prob,
119+
max_angle=spatial_shift_max_angle,
120+
max_px_shift=spatial_shift_max_displacement,
121+
interpolation_type=spatial_shift_interpolation_type,
122+
),
123+
T.ConvertImageDtype(torch.float32),
124+
T.RandomRescaleAndCrop(
125+
crop_size=crop_size,
126+
scale_range=scale_range,
127+
rescale_prob=rescale_prob,
128+
scaling_type=scaling_type,
129+
interpolation_type=scale_interpolation_type,
130+
),
131+
T.RandomHorizontalFlip(horizontal_flip_prob),
132+
# occlusion after flip, otherwise we're occluding the reference image
133+
T.RandomOcclusion(p=occlusion_prob, occlusion_px_range=occlusion_px_range),
134+
T.RandomErase(p=erase_prob, erase_px_range=erase_px_range, max_erase=erase_num_repeats),
135+
T.Normalize(mean=mean, std=std),
136+
T.MakeValidDisparityMask(max_disparity),
137+
T.ValidateModelInput(),
138+
]
139+
)
140+
141+
self.transforms = T.Compose(transforms)
142+
143+
def forward(self, images, disparties, mask):
144+
return self.transforms(images, disparties, mask)

0 commit comments

Comments
 (0)