
Commit 9de8cc9

Added transforms for Stereo Matching
1 parent 112accf commit 9de8cc9

File tree

2 files changed: +675 -0 lines changed


references/depth/stereo/presets.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
from typing import Tuple, Union

import torch

import transforms as T


class StereoMatchingEvalPreset(torch.nn.Module):
    def __init__(self, size=None) -> None:
        super().__init__()

        transforms = [
            T.ToTensor(),
            T.MakeValidDisparityMask(512),  # we keep this transform for API consistency
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=0.5, std=0.5),
            T.ValidateModelInput(),
        ]

        if size is not None:
            transforms = transforms + [T.Resize(size)]

        self.transforms = T.Compose(transforms)

    def forward(self, images, disparities, masks):
        return self.transforms(images, disparities, masks)

class StereoMatchingTrainPreset(torch.nn.Module):
    def __init__(
        self,
        *,
        # RandomResizeAndCrop params
        crop_size: Tuple[int, int],
        min_scale: float = -0.2,
        max_scale: float = 0.5,
        resize_prob: float = 1.0,
        scaling_type: str = "exponential",
        # processing device
        gpu_transforms=False,
        # masking
        max_disparity: int = 256,
        # AsymmetricColorJitter
        gamma_range: Tuple[float, float] = (0.8, 1.2),
        brightness: Union[int, Tuple[int, int]] = (0.8, 1.2),
        contrast: Union[int, Tuple[int, int]] = (0.8, 1.2),
        saturation: Union[int, Tuple[int, int]] = 0.0,
        hue: Union[int, Tuple[int, int]] = 0.0,
        asymmetric_jitter_prob: float = 1.0,
        # RandomHorizontalFlip
        do_flip=True,
        # RandomOcclusion
        occlusion_prob: float = 0.0,
        occlusion_min_px: int = 50,
        occlusion_max_px: int = 100,
        # RandomErase
        erase_prob: float = 0.0,
        erase_min_px: int = 50,
        erase_max_px: int = 100,
        erase_num_repeats: int = 1,
    ) -> None:

        if scaling_type not in ["linear", "exponential"]:
            raise ValueError(f"Unknown scaling type: {scaling_type}. Available types: linear, exponential")

        super().__init__()
        transforms = [T.ToTensor()]
        if gpu_transforms:
            transforms.append(T.ToGPU())

        # photometric and spatial augmentations are appended so that the
        # ToTensor / ToGPU transforms above are not discarded
        transforms += [
            T.AsymmetricColorJitter(
                brightness=brightness, contrast=contrast, saturation=saturation, hue=hue, p=asymmetric_jitter_prob
            ),
            T.AsymetricGammaAdjust(p=asymmetric_jitter_prob, gamma_range=gamma_range),
            T.RandomSpatialShift(),
            T.ConvertImageDtype(torch.float32),
            T.RandomResizeAndCrop(
                crop_size=crop_size,
                min_scale=min_scale,
                max_scale=max_scale,
                resize_prob=resize_prob,
                scaling_type=scaling_type,
            ),
        ]

        if do_flip:
            transforms += [T.RandomHorizontalFlip()]

        transforms += [
            # occlusion after flip, otherwise we're occluding the reference image
            T.RandomOcclusion(p=occlusion_prob, min_px=occlusion_min_px, max_px=occlusion_max_px),
            T.RandomErase(p=erase_prob, min_px=erase_min_px, max_px=erase_max_px, num_repeats=erase_num_repeats),
            T.Normalize(mean=0.5, std=0.5),
            T.MakeValidDisparityMask(max_disparity),
            T.ValidateModelInput(),
        ]

        self.transforms = T.Compose(transforms)

    def forward(self, images, disparities, masks):
        return self.transforms(images, disparities, masks)
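
For context, a minimal sketch of how the evaluation preset might be instantiated; the module path `presets` is a placeholder, and the exact layout of `images` / `disparities` / `masks` (left/right pairs) is defined by the companion transforms.py added in this commit, so the call itself is left as a comment:

# hypothetical usage, not part of this commit
from presets import StereoMatchingEvalPreset

eval_transforms = StereoMatchingEvalPreset(size=None)  # a non-None size appends T.Resize(size)
# images, disparities, masks = eval_transforms(images, disparities, masks)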

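Similarly, a hedged sketch of the training preset; only crop_size is required, and the values shown are either placeholders or the defaults from the diff above:

# hypothetical usage, not part of this commit
from presets import StereoMatchingTrainPreset

train_transforms = StereoMatchingTrainPreset(
    crop_size=(384, 512),        # required; every other argument has a default
    scaling_type="exponential",  # "linear" is the only other accepted value
    max_disparity=256,           # forwarded to T.MakeValidDisparityMask
    do_flip=True,                # adds T.RandomHorizontalFlip before occlusion/erase
    gpu_transforms=False,        # True appends T.ToGPU() right after T.ToTensor()
)
# images, disparities, masks = train_transforms(images, disparities, masks)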
0 commit comments
