Skip to content

add crestereo implementation #6310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
39efae5
crestereo draft implementation
TeodorPoncu Jul 25, 2022
93bc4c6
minor model fixes. positional embedding changes.
TeodorPoncu Jul 29, 2022
7f2d924
aligned base configuration with paper
TeodorPoncu Jul 29, 2022
bf7f2ff
Merge branch 'main' into add-crestereo-model
jdsgomes Aug 15, 2022
9fa1b50
Addressing comments
TeodorPoncu Aug 16, 2022
1e63f3d
Merge branch 'add-crestereo-model' of https://github.com/pytorch/visi…
TeodorPoncu Aug 16, 2022
f843014
Broke down Adaptive Correlation Layer. Addressed some other comments.
TeodorPoncu Sep 12, 2022
011582d
addressed some nits
TeodorPoncu Sep 12, 2022
396bdcd
changed search size, added output channels to model attrs
TeodorPoncu Sep 13, 2022
d3e6fd5
changed weights naming
TeodorPoncu Sep 13, 2022
07c15da
changed from iterations to num_iters
TeodorPoncu Sep 13, 2022
0ba1c37
removed _make_coords, addressed comments
TeodorPoncu Sep 14, 2022
b7f269d
Merge branch 'main' into add-crestereo-model
TeodorPoncu Sep 14, 2022
13f81f6
fixed jit test
TeodorPoncu Sep 14, 2022
94d2d43
Merge branch 'main' into add-crestereo-model
jdsgomes Sep 14, 2022
b224caa
Merge branch 'main' into add-crestereo-model
jdsgomes Sep 14, 2022
e9dfe3e
Merge branch 'main' into add-crestereo-model
YosuaMichael Sep 15, 2022
016501b
config nit
TeodorPoncu Sep 15, 2022
c5cc082
Merge branch 'add-crestereo-model' of https://github.com/pytorch/visi…
TeodorPoncu Sep 15, 2022
d74d28c
Changed device arg to str
TeodorPoncu Sep 15, 2022
3989652
Merge branch 'main' into add-crestereo-model
YosuaMichael Sep 15, 2022
1307cd3
Merge branch 'main' into add-crestereo-model
YosuaMichael Sep 15, 2022
6cead36
Merge branch 'main' into add-crestereo-model
jdsgomes Sep 15, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
50 changes: 48 additions & 2 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torchvision.prototype import models


@pytest.mark.parametrize("model_fn", TM.list_model_fns(models.depth.stereo))
@pytest.mark.parametrize("model_fn", (models.depth.stereo.raft_stereo_base,))
@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_raft_stereo(model_fn, model_mode, dev):
Expand Down Expand Up @@ -35,4 +35,50 @@ def test_raft_stereo(model_fn, model_mode, dev):
), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}"

# Test against expected file output
TM._assert_expected(depth_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1e-2)
TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("model_fn", (models.depth.stereo.crestereo_base,))
@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
@pytest.mark.parametrize("dev", cpu_and_gpu())
def test_crestereo(model_fn, model_mode, dev):
    """Smoke-test CREStereo: prediction count, output shapes, and a backward pass.

    Runs the model (optionally TorchScript-ed) on random 64x64 inputs and checks
    that the cascaded refinement produces the expected number of disparity maps,
    that every map has shape (1, 2, 64, 64), and that gradients flow through a
    dummy MSE loss. Finally compares against the stored expected output.
    """
    set_rng_seed(0)

    model = model_fn().eval().to(dev)

    if model_mode == "scripted":
        model = torch.jit.script(model)

    img1 = torch.rand(1, 3, 64, 64).to(dev)
    img2 = torch.rand(1, 3, 64, 64).to(dev)
    iterations = 3

    preds = model(img1, img2, flow_init=None, num_iters=iterations)
    # The last prediction is the highest-resolution, most-refined disparity.
    disparity_pred = preds[-1]

    # all the pyramid levels except the highest res make only half the number of iterations
    expected_iterations = (iterations // 2) * (len(model.resolutions) - 1)
    expected_iterations += iterations
    assert (
        len(preds) == expected_iterations
    ), "Number of predictions should be the number of iterations multiplied by the number of pyramid levels"

    assert disparity_pred.shape == torch.Size(
        [1, 2, 64, 64]
    ), f"Predicted disparity should have the same spatial shape as the input. Inputs shape {img1.shape[2:]}, Prediction shape {disparity_pred.shape[2:]}"

    assert all(
        d.shape == torch.Size([1, 2, 64, 64]) for d in preds
    ), "All predicted disparities are expected to have the same shape"

    # test a backward pass with a dummy loss as well
    preds = torch.stack(preds, dim=0)
    targets = torch.ones_like(preds, requires_grad=False)
    loss = torch.nn.functional.mse_loss(preds, targets)

    try:
        loss.backward()
    except Exception as e:
        # Use pytest.fail rather than `assert False`: asserts are stripped under
        # `python -O`, which would silently skip this failure path (ruff B011).
        pytest.fail(f"Backward pass failed with an unexpected exception: {e.__class__.__name__} {e}")

    TM._assert_expected(disparity_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)
4 changes: 2 additions & 2 deletions torchvision/models/optical_flow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def grid_sample(img: Tensor, absolute_grid: Tensor, mode: str = "bilinear", alig
return F.grid_sample(img, normalized_grid, mode=mode, align_corners=align_corners)


def make_coords_grid(batch_size: int, h: int, w: int):
coords = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
def make_coords_grid(batch_size: int, h: int, w: int, device: torch.device = torch.device("cpu")):
    """Build a batch of pixel-coordinate grids of shape (batch_size, 2, h, w).

    Channel 0 holds the x (column) coordinate and channel 1 the y (row)
    coordinate of every pixel, as float32, on the requested device.
    """
    rows = torch.arange(h, device=device)
    cols = torch.arange(w, device=device)
    y_grid, x_grid = torch.meshgrid(rows, cols, indexing="ij")
    # Stack as (x, y) so the grid matches the (u, v) flow-field convention.
    grid = torch.stack((x_grid, y_grid), dim=0).float()
    return grid.unsqueeze(0).repeat(batch_size, 1, 1, 1)

Expand Down
11 changes: 9 additions & 2 deletions torchvision/models/optical_flow/raft.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class ResidualBlock(nn.Module):
"""Slightly modified Residual block with extra relu and biases."""

def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
def __init__(self, in_channels, out_channels, *, norm_layer, stride=1, always_project: bool = False):
super().__init__()

# Note regarding bias=True:
Expand All @@ -43,7 +43,10 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1):
out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True
)

if stride == 1:
# make mypy happy
self.downsample: nn.Module

if stride == 1 and not always_project:
self.downsample = nn.Identity()
else:
self.downsample = Conv2dNormActivation(
Expand Down Expand Up @@ -144,6 +147,10 @@ def __init__(
if m.bias is not None:
nn.init.constant_(m.bias, 0)

num_downsamples = len(list(filter(lambda s: s == 2, strides)))
self.output_dim = layers[-1]
self.downsample_factor = 2**num_downsamples

def _make_2_blocks(self, block, in_channels, out_channels, norm_layer, first_stride):
block1 = block(in_channels, out_channels, norm_layer=norm_layer, stride=first_stride)
block2 = block(out_channels, out_channels, norm_layer=norm_layer, stride=1)
Expand Down
1 change: 1 addition & 0 deletions torchvision/prototype/models/depth/stereo/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .raft_stereo import *
from .crestereo import *
Loading