@@ -5,7 +5,7 @@
 from torchvision.prototype import models
 
 
-@pytest.mark.parametrize("model_fn", TM.list_model_fns(models.depth.stereo))
+@pytest.mark.parametrize("model_fn", (models.depth.stereo.raft_stereo_base,))
 @pytest.mark.parametrize("model_mode", ("standard", "scripted"))
 @pytest.mark.parametrize("dev", cpu_and_gpu())
 def test_raft_stereo(model_fn, model_mode, dev):
@@ -35,4 +35,50 @@ def test_raft_stereo(model_fn, model_mode, dev):
     ), f"The output shape of depth_pred should be [1, 1, 64, 64] but instead it is {preds[0].shape}"
 
     # Test against expected file output
-    TM._assert_expected(depth_pred.cpu(), name=model_fn.__name__, atol=1e-2, rtol=1e-2)
+    TM._assert_expected(depth_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)
+
+
+@pytest.mark.parametrize("model_fn", (models.depth.stereo.crestereo_base,))
+@pytest.mark.parametrize("model_mode", ("standard", "scripted"))
+@pytest.mark.parametrize("dev", cpu_and_gpu())
+def test_crestereo(model_fn, model_mode, dev):
+    set_rng_seed(0)
+
+    model = model_fn().eval().to(dev)
+
+    if model_mode == "scripted":
+        model = torch.jit.script(model)
+
+    img1 = torch.rand(1, 3, 64, 64).to(dev)
+    img2 = torch.rand(1, 3, 64, 64).to(dev)
+    iterations = 3
+
+    preds = model(img1, img2, flow_init=None, num_iters=iterations)
+    disparity_pred = preds[-1]
+
+    # all the pyramid levels except the highest res make only half the number of iterations
+    expected_iterations = (iterations // 2) * (len(model.resolutions) - 1)
+    expected_iterations += iterations
+    assert (
+        len(preds) == expected_iterations
+    ), "Number of predictions should be the number of iterations multiplied by the number of pyramid levels"
+
+    assert disparity_pred.shape == torch.Size(
+        [1, 2, 64, 64]
+    ), f"Predicted disparity should have the same spatial shape as the input. Inputs shape {img1.shape[2:]}, Prediction shape {disparity_pred.shape[2:]}"
+
+    assert all(
+        d.shape == torch.Size([1, 2, 64, 64]) for d in preds
+    ), "All predicted disparities are expected to have the same shape"
+
+    # test a backward pass with a dummy loss as well
+    preds = torch.stack(preds, dim=0)
+    targets = torch.ones_like(preds, requires_grad=False)
+    loss = torch.nn.functional.mse_loss(preds, targets)
+
+    try:
+        loss.backward()
+    except Exception as e:
+        assert False, f"Backward pass failed with an unexpected exception: {e.__class__.__name__} {e}"
+
+    TM._assert_expected(disparity_pred, name=model_fn.__name__, atol=1e-2, rtol=1e-2)