Skip to content

Commit 09779fd

Browse files
YosuaMichael, jdsgomes
authored and committed
[fbsync] add crestereo implementation (#6310)
Summary: * crestereo draft implementation * minor model fixes. positional embedding changes. * aligned base configuration with paper * Addressing comments * Broke down Adaptive Correlation Layer. Addressed some other comments. * addressed some nits * changed search size, added output channels to model attrs * changed weights naming * changed from iterations to num_iters * removed _make_coords, addressed comments * fixed jit test * config nit * Changed device arg to str Reviewed By: jdsgomes Differential Revision: D39543279 fbshipit-source-id: c6101958588eb43201f92ff4f687bd32cbbcbbd1 Co-authored-by: Joao Gomes <[email protected]> Co-authored-by: YosuaMichael <[email protected]>
1 parent 18190f7 commit 09779fd

File tree

6 files changed

+1156
-6
lines changed

6 files changed

+1156
-6
lines changed
Binary file not shown.

test/test_prototype_models.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from torchvision.prototype import models
66

77

8-
@pytest.mark.parametrize("model_fn", TM.list_model_fns(models.depth.stereo))
8+
@pytest.mark.parametrize("model_fn", (models.depth.stereo.raft_stereo_base,))
99
@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
1010
@pytest.mark.parametrize("dev", cpu_and_gpu())
1111
def test_raft_stereo(model_fn, model_mode, dev):
@@ -35,4 +35,50 @@ def test_raft_stereo(model_fn, model_mode, dev):
3535
), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}"
3636

3737
# Test against expected file output
38-
TM._assert_expected(depth_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1e-2)
38+
TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)
39+
40+
41+
@pytest.mark.parametrize("model_fn", (models.depth.stereo.crestereo_base,))
@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_crestereo(model_fn, model_mode, dev):
    """Smoke-test the CREStereo builder in eager and TorchScript mode.

    Checks, for a pair of random 64x64 images:
      * the number of predictions returned for ``num_iters=3``,
      * the shape of every prediction,
      * that a backward pass through a dummy MSE loss succeeds,
      * that the last prediction matches the stored expected output.
    """
    set_rng_seed(0)

    model = model_fn().eval().to(dev)

    if model_mode == "scripted":
        model = torch.jit.script(model)

    # NOTE(review): inputs are created first and then moved with `.to(dev)`;
    # presumably this keeps the random values identical across devices so the
    # stored expected output stays valid - confirm before changing.
    img1 = torch.rand(1, 3, 64, 64).to(dev)
    img2 = torch.rand(1, 3, 64, 64).to(dev)
    iterations = 3

    preds = model(img1, img2, flow_init=None, num_iters=iterations)
    disparity_pred = preds[-1]

    # all the pyramid levels except the highest res make only half the number of iterations
    expected_iterations = (iterations // 2) * (len(model.resolutions) - 1)
    expected_iterations += iterations
    assert len(preds) == expected_iterations, (
        "Number of predictions should be num_iters for the highest resolution plus "
        f"num_iters // 2 for every other pyramid level. Expected {expected_iterations}, got {len(preds)}"
    )

    assert disparity_pred.shape == torch.Size(
        [1, 2, 64, 64]
    ), f"Predicted disparity should have the same spatial shape as the input. Inputs shape {img1.shape[2:]}, Prediction shape {disparity_pred.shape[2:]}"

    assert all(
        d.shape == torch.Size([1, 2, 64, 64]) for d in preds
    ), "All predicted disparities are expected to have the same shape"

    # test a backward pass with a dummy loss as well
    preds = torch.stack(preds, dim=0)
    targets = torch.ones_like(preds, requires_grad=False)
    loss = torch.nn.functional.mse_loss(preds, targets)

    try:
        loss.backward()
    except Exception as e:
        # `assert False` would be stripped under `python -O` and would drop the
        # original traceback; raise explicitly and chain the cause instead.
        raise AssertionError(
            f"Backward pass failed with an unexpected exception: {e.__class__.__name__} {e}"
        ) from e

    TM._assert_expected(disparity_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)

torchvision/models/optical_flow/_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", alig
1919
return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)
2020

2121

22-
def make_coords_grid(batch_size: int, h: int, w: int):
23-
coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
22+
def make_coords_grid(batch_size: int, h: int, w: int, device: str = "cpu"):
    """Build a batch of pixel-coordinate grids.

    Args:
        batch_size: number of identical grids to return.
        h: grid height (number of rows).
        w: grid width (number of columns).
        device: device string forwarded to ``torch.device``.

    Returns:
        A float tensor of shape ``(batch_size, 2, h, w)``. Channel 0 holds the
        column index of each position and channel 1 holds the row index.
    """
    target = torch.device(device)
    row_idx = torch.arange(h, device=target)
    col_idx = torch.arange(w, device=target)

    # meshgrid with "ij" indexing yields (rows, cols); swap so columns come first.
    mesh = torch.meshgrid(row_idx, col_idx, indexing="ij")
    grid = torch.stack([mesh[1], mesh[0]], dim=0).float()

    # Prepend a batch axis and copy the grid once per batch element.
    return grid.unsqueeze(0).repeat(batch_size, 1, 1, 1)
2627

torchvision/models/optical_flow/raft.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
class ResidualBlock(nn.Module):
2828
"""Slightly modified Residual block with extra relu and biases."""
2929

30-
def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
30+
def __init__(self, in_channels, out_channels, *, norm_layer, stride=1, always_project: bool = False):
3131
super().__init__()
3232

3333
# Note regarding bias=True:
@@ -43,7 +43,10 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
4343
out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True
4444
)
4545

46-
if stride == 1:
46+
# make mypy happy
47+
self.downsample: nn.Module
48+
49+
if stride == 1 and not always_project:
4750
self.downsample = nn.Identity()
4851
else:
4952
self.downsample = Conv2dNormActivation(
@@ -144,6 +147,10 @@ def __init__(
144147
if m.bias is not None:
145148
nn.init.constant_(m.bias, 0)
146149

150+
num_downsamples = len(list(filter(lambda s: s == 2, strides)))
151+
self.output_dim = layers[-1]
152+
self.downsample_factor = 2**num_downsamples
153+
147154
def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride):
148155
block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride)
149156
block2 = block(out_channels, out_channels, norm_layer=norm_layer, stride=1)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .raft_stereo import *
2+
from .crestereo import *

0 commit comments

Comments
 (0)