Skip to content

Commit 355b278

Browse files
authored
Update RAFT Stereo to be more sync with CREStereo implementation (#6575)
* Update raft_stereo to sync forward param with CREStereo and have output_channel * Add flow_init param to docstring * Use output_channels instead of output_channel * Replace depth with disparity since what we predict actually disparity instead of actual depth
1 parent 1ea73f5 commit 355b278

File tree

1 file changed

+48
-35
lines changed

1 file changed

+48
-35
lines changed

torchvision/prototype/models/depth/stereo/raft_stereo.py

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def forward(
204204
hidden_states: List[Tensor],
205205
contexts: List[List[Tensor]],
206206
corr_features: Tensor,
207-
depth: Tensor,
207+
disparity: Tensor,
208208
level_processed: List[bool],
209209
) -> List[Tensor]:
210210
# We call it reverse_i because it has a reversed ordering compared to hidden_states
@@ -215,7 +215,7 @@ def forward(
215215
# X is concatination of 2x downsampled hidden_dim (or motion_features if no bigger dim) with
216216
# upsampled hidden_dim (or nothing if not exist).
217217
if i == 0:
218-
features = self.motion_encoder(depth, corr_features)
218+
features = self.motion_encoder(disparity, corr_features)
219219
else:
220220
# 2x downsampled features from larger hidden states
221221
features = F.avg_pool2d(hidden_states[i - 1], kernel_size=3, stride=2, padding=1)
@@ -235,14 +235,14 @@ def forward(
235235

236236
hidden_states[i] = gru(hidden_states[i], features, contexts[i])
237237

238-
# NOTE: For slow-fast gru, we dont always want to calculate delta depth for every call on UpdateBlock
239-
# Hence we move the delta depth calculation to the RAFT-Stereo main forward
238+
# NOTE: For slow-fast gru, we dont always want to calculate delta disparity for every call on UpdateBlock
239+
# Hence we move the delta disparity calculation to the RAFT-Stereo main forward
240240

241241
return hidden_states
242242

243243

244244
class MaskPredictor(raft.MaskPredictor):
245-
"""Mask predictor to be used when upsampling the predicted depth."""
245+
"""Mask predictor to be used when upsampling the predicted disparity."""
246246

247247
# We add out_channels compared to raft.MaskPredictor
248248
def __init__(self, *, in_channels: int, hidden_size: int, out_channels: int, multiplier: float = 0.25):
@@ -346,16 +346,16 @@ def __init__(
346346
corr_pyramid: CorrPyramid1d,
347347
corr_block: CorrBlock1d,
348348
update_block: MultiLevelUpdateBlock,
349-
depth_head: nn.Module,
349+
disparity_head: nn.Module,
350350
mask_predictor: Optional[nn.Module] = None,
351351
slow_fast: bool = False,
352352
):
353353
"""RAFT-Stereo model from
354354
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
355355
356356
args:
357-
feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``image1`` and ``image2``.
358-
context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``image1``.
357+
feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``left_image`` and ``right_image``.
358+
context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``left_image``.
359359
It has multi-level output and each level will have 2 parts:
360360
361361
- one part will be used as the actual "context", passed to the recurrent unit of the ``update_block``
@@ -370,8 +370,8 @@ def __init__(
370370
371371
update_block (MultiLevelUpdateBlock): The update block, which contains the motion encoder, and the recurrent unit.
372372
It takes as input the hidden state of its recurrent unit, the context, the correlation
373-
features, and the current predicted depth. It outputs an updated hidden state
374-
depth_head (nn.Module): The depth head block will convert from the hidden state into changes in depth.
373+
features, and the current predicted disparity. It outputs an updated hidden state
374+
disparity_head (nn.Module): The disparity head block will convert from the hidden state into changes in disparity.
375375
mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
376376
If ``None`` (default), the flow is upsampled using interpolation.
377377
slow_fast (bool): A boolean that specify whether we should use slow-fast GRU or not. See RAFT-Stereo paper
@@ -380,6 +380,10 @@ def __init__(
380380
super().__init__()
381381
_log_api_usage_once(self)
382382

383+
# This indicate that the disparity output will be only have 1 channel (represent horizontal axis).
384+
# We need this because some stereo matching model like CREStereo might have 2 channel on the output
385+
self.output_channels = 1
386+
383387
self.feature_encoder = feature_encoder
384388
self.context_encoder = context_encoder
385389

@@ -388,7 +392,7 @@ def __init__(
388392
self.corr_pyramid = corr_pyramid
389393
self.corr_block = corr_block
390394
self.update_block = update_block
391-
self.depth_head = depth_head
395+
self.disparity_head = disparity_head
392396
self.mask_predictor = mask_predictor
393397

394398
hidden_dims = self.update_block.hidden_dims
@@ -399,26 +403,29 @@ def __init__(
399403
)
400404
self.slow_fast = slow_fast
401405

402-
def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[Tensor]:
406+
def forward(
407+
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 12
408+
) -> List[Tensor]:
403409
"""
404-
Return dept predictions on every iterations as a list of Tensor.
410+
Return disparity predictions on every iterations as a list of Tensor.
405411
args:
406-
image1 (Tensor): The input left image with layout B, C, H, W
407-
image2 (Tensor): The input right image with layout B, C, H, W
412+
left_image (Tensor): The input left image with layout B, C, H, W
413+
right_image (Tensor): The input right image with layout B, C, H, W
414+
flow_init (Optional[Tensor]): Initial estimate for the disparity. Default: None
408415
num_iters (int): Number of update block iteration on the largest resolution. Default: 12
409416
"""
410-
batch_size, _, h, w = image1.shape
417+
batch_size, _, h, w = left_image.shape
411418
torch._assert(
412-
(h, w) == image2.shape[-2:],
413-
f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}",
419+
(h, w) == right_image.shape[-2:],
420+
f"input images should have the same shape, instead got ({h}, {w}) != {right_image.shape[-2:]}",
414421
)
415422

416423
torch._assert(
417424
(h % self.base_downsampling_ratio == 0 and w % self.base_downsampling_ratio == 0),
418425
f"input image H and W should be divisible by {self.base_downsampling_ratio}, insted got H={h} and W={w}",
419426
)
420427

421-
fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
428+
fmaps = self.feature_encoder(torch.cat([left_image, right_image], dim=0))
422429
fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
423430
torch._assert(
424431
fmap1.shape[-2:] == (h // self.base_downsampling_ratio, w // self.base_downsampling_ratio),
@@ -428,7 +435,7 @@ def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[T
428435
corr_pyramid = self.corr_pyramid(fmap1, fmap2)
429436

430437
# Multi level contexts
431-
context_outs = self.context_encoder(image1)
438+
context_outs = self.context_encoder(left_image)
432439

433440
hidden_dims = self.update_block.hidden_dims
434441
context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))]
@@ -448,35 +455,41 @@ def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[T
448455
coords0 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
449456
coords1 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
450457

451-
depth_predictions = []
458+
# We use flow_init for cascade inference
459+
if flow_init is not None:
460+
coords1 = coords1 + flow_init
461+
462+
disparity_predictions = []
452463
for _ in range(num_iters):
453464
coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper
454465
corr_features = self.corr_block(centroids_coords=coords1, corr_pyramid=corr_pyramid)
455466

456-
depth = coords1 - coords0
467+
disparity = coords1 - coords0
457468
if self.slow_fast:
458469
# Using slow_fast GRU (see paper section 3.4). The lower resolution are processed more often
459470
for i in range(1, self.num_level):
460471
# We only processed the smallest i levels
461472
level_processed = [False] * (self.num_level - i) + [True] * i
462473
hidden_states = self.update_block(
463-
hidden_states, contexts, corr_features, depth, level_processed=level_processed
474+
hidden_states, contexts, corr_features, disparity, level_processed=level_processed
464475
)
465476
hidden_states = self.update_block(
466-
hidden_states, contexts, corr_features, depth, level_processed=[True] * self.num_level
477+
hidden_states, contexts, corr_features, disparity, level_processed=[True] * self.num_level
467478
)
468-
# Take the largest hidden_state to get the depth
479+
# Take the largest hidden_state to get the disparity
469480
hidden_state = hidden_states[0]
470-
delta_depth = self.depth_head(hidden_state)
471-
# in stereo mode, project depth onto epipolar
472-
delta_depth[:, 1] = 0.0
481+
delta_disparity = self.disparity_head(hidden_state)
482+
# in stereo mode, project disparity onto epipolar
483+
delta_disparity[:, 1] = 0.0
473484

474-
coords1 = coords1 + delta_depth
485+
coords1 = coords1 + delta_disparity
475486
up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
476-
upsampled_depth = upsample_flow((coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio)
477-
depth_predictions.append(upsampled_depth[:, :1])
487+
upsampled_disparity = upsample_flow(
488+
(coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio
489+
)
490+
disparity_predictions.append(upsampled_disparity[:, :1])
478491

479-
return depth_predictions
492+
return disparity_predictions
480493

481494

482495
def _raft_stereo(
@@ -576,8 +589,8 @@ def _raft_stereo(
576589
motion_encoder=motion_encoder, hidden_dims=update_block_hidden_dims
577590
)
578591

579-
# We use the largest scale hidden_dims of update_block to get the predicted depth
580-
depth_head = kwargs.pop("depth_head", None) or FlowHead(
592+
# We use the largest scale hidden_dims of update_block to get the predicted disparity
593+
disparity_head = kwargs.pop("disparity_head", None) or FlowHead(
581594
in_channels=update_block_hidden_dims[0],
582595
hidden_size=flow_head_hidden_size,
583596
)
@@ -598,7 +611,7 @@ def _raft_stereo(
598611
corr_pyramid=corr_pyramid,
599612
corr_block=corr_block,
600613
update_block=update_block,
601-
depth_head=depth_head,
614+
disparity_head=disparity_head,
602615
mask_predictor=mask_predictor,
603616
slow_fast=slow_fast,
604617
**kwargs, # not really needed, all params should be consumed by now

0 commit comments

Comments
 (0)