@@ -204,7 +204,7 @@ def forward(
204
204
hidden_states : List [Tensor ],
205
205
contexts : List [List [Tensor ]],
206
206
corr_features : Tensor ,
207
- depth : Tensor ,
207
+ disparity : Tensor ,
208
208
level_processed : List [bool ],
209
209
) -> List [Tensor ]:
210
210
# We call it reverse_i because it has a reversed ordering compared to hidden_states
@@ -215,7 +215,7 @@ def forward(
215
215
# X is concatination of 2x downsampled hidden_dim (or motion_features if no bigger dim) with
216
216
# upsampled hidden_dim (or nothing if not exist).
217
217
if i == 0 :
218
- features = self .motion_encoder (depth , corr_features )
218
+ features = self .motion_encoder (disparity , corr_features )
219
219
else :
220
220
# 2x downsampled features from larger hidden states
221
221
features = F .avg_pool2d (hidden_states [i - 1 ], kernel_size = 3 , stride = 2 , padding = 1 )
@@ -235,14 +235,14 @@ def forward(
235
235
236
236
hidden_states [i ] = gru (hidden_states [i ], features , contexts [i ])
237
237
238
- # NOTE: For slow-fast gru, we dont always want to calculate delta depth for every call on UpdateBlock
239
- # Hence we move the delta depth calculation to the RAFT-Stereo main forward
238
+ # NOTE: For slow-fast gru, we dont always want to calculate delta disparity for every call on UpdateBlock
239
+ # Hence we move the delta disparity calculation to the RAFT-Stereo main forward
240
240
241
241
return hidden_states
242
242
243
243
244
244
class MaskPredictor (raft .MaskPredictor ):
245
- """Mask predictor to be used when upsampling the predicted depth ."""
245
+ """Mask predictor to be used when upsampling the predicted disparity ."""
246
246
247
247
# We add out_channels compared to raft.MaskPredictor
248
248
def __init__ (self , * , in_channels : int , hidden_size : int , out_channels : int , multiplier : float = 0.25 ):
@@ -346,16 +346,16 @@ def __init__(
346
346
corr_pyramid : CorrPyramid1d ,
347
347
corr_block : CorrBlock1d ,
348
348
update_block : MultiLevelUpdateBlock ,
349
- depth_head : nn .Module ,
349
+ disparity_head : nn .Module ,
350
350
mask_predictor : Optional [nn .Module ] = None ,
351
351
slow_fast : bool = False ,
352
352
):
353
353
"""RAFT-Stereo model from
354
354
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.
355
355
356
356
args:
357
- feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``image1 `` and ``image2 ``.
358
- context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``image1 ``.
357
+ feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``left_image `` and ``right_image ``.
358
+ context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``left_image ``.
359
359
It has multi-level output and each level will have 2 parts:
360
360
361
361
- one part will be used as the actual "context", passed to the recurrent unit of the ``update_block``
@@ -370,8 +370,8 @@ def __init__(
370
370
371
371
update_block (MultiLevelUpdateBlock): The update block, which contains the motion encoder, and the recurrent unit.
372
372
It takes as input the hidden state of its recurrent unit, the context, the correlation
373
- features, and the current predicted depth . It outputs an updated hidden state
374
- depth_head (nn.Module): The depth head block will convert from the hidden state into changes in depth .
373
+ features, and the current predicted disparity . It outputs an updated hidden state
374
+ disparity_head (nn.Module): The disparity head block will convert from the hidden state into changes in disparity .
375
375
mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
376
376
If ``None`` (default), the flow is upsampled using interpolation.
377
377
slow_fast (bool): A boolean that specify whether we should use slow-fast GRU or not. See RAFT-Stereo paper
@@ -380,6 +380,10 @@ def __init__(
380
380
super ().__init__ ()
381
381
_log_api_usage_once (self )
382
382
383
+ # This indicate that the disparity output will be only have 1 channel (represent horizontal axis).
384
+ # We need this because some stereo matching model like CREStereo might have 2 channel on the output
385
+ self .output_channels = 1
386
+
383
387
self .feature_encoder = feature_encoder
384
388
self .context_encoder = context_encoder
385
389
@@ -388,7 +392,7 @@ def __init__(
388
392
self .corr_pyramid = corr_pyramid
389
393
self .corr_block = corr_block
390
394
self .update_block = update_block
391
- self .depth_head = depth_head
395
+ self .disparity_head = disparity_head
392
396
self .mask_predictor = mask_predictor
393
397
394
398
hidden_dims = self .update_block .hidden_dims
@@ -399,26 +403,29 @@ def __init__(
399
403
)
400
404
self .slow_fast = slow_fast
401
405
402
- def forward (self , image1 : Tensor , image2 : Tensor , num_iters : int = 12 ) -> List [Tensor ]:
406
+ def forward (
407
+ self , left_image : Tensor , right_image : Tensor , flow_init : Optional [Tensor ] = None , num_iters : int = 12
408
+ ) -> List [Tensor ]:
403
409
"""
404
- Return dept predictions on every iterations as a list of Tensor.
410
+ Return disparity predictions on every iterations as a list of Tensor.
405
411
args:
406
- image1 (Tensor): The input left image with layout B, C, H, W
407
- image2 (Tensor): The input right image with layout B, C, H, W
412
+ left_image (Tensor): The input left image with layout B, C, H, W
413
+ right_image (Tensor): The input right image with layout B, C, H, W
414
+ flow_init (Optional[Tensor]): Initial estimate for the disparity. Default: None
408
415
num_iters (int): Number of update block iteration on the largest resolution. Default: 12
409
416
"""
410
- batch_size , _ , h , w = image1 .shape
417
+ batch_size , _ , h , w = left_image .shape
411
418
torch ._assert (
412
- (h , w ) == image2 .shape [- 2 :],
413
- f"input images should have the same shape, instead got ({ h } , { w } ) != { image2 .shape [- 2 :]} " ,
419
+ (h , w ) == right_image .shape [- 2 :],
420
+ f"input images should have the same shape, instead got ({ h } , { w } ) != { right_image .shape [- 2 :]} " ,
414
421
)
415
422
416
423
torch ._assert (
417
424
(h % self .base_downsampling_ratio == 0 and w % self .base_downsampling_ratio == 0 ),
418
425
f"input image H and W should be divisible by { self .base_downsampling_ratio } , insted got H={ h } and W={ w } " ,
419
426
)
420
427
421
- fmaps = self .feature_encoder (torch .cat ([image1 , image2 ], dim = 0 ))
428
+ fmaps = self .feature_encoder (torch .cat ([left_image , right_image ], dim = 0 ))
422
429
fmap1 , fmap2 = torch .chunk (fmaps , chunks = 2 , dim = 0 )
423
430
torch ._assert (
424
431
fmap1 .shape [- 2 :] == (h // self .base_downsampling_ratio , w // self .base_downsampling_ratio ),
@@ -428,7 +435,7 @@ def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[T
428
435
corr_pyramid = self .corr_pyramid (fmap1 , fmap2 )
429
436
430
437
# Multi level contexts
431
- context_outs = self .context_encoder (image1 )
438
+ context_outs = self .context_encoder (left_image )
432
439
433
440
hidden_dims = self .update_block .hidden_dims
434
441
context_out_channels = [context_outs [i ].shape [1 ] - hidden_dims [i ] for i in range (len (context_outs ))]
@@ -448,35 +455,41 @@ def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[T
448
455
coords0 = make_coords_grid (batch_size , Hf , Wf ).to (fmap1 .device )
449
456
coords1 = make_coords_grid (batch_size , Hf , Wf ).to (fmap1 .device )
450
457
451
- depth_predictions = []
458
+ # We use flow_init for cascade inference
459
+ if flow_init is not None :
460
+ coords1 = coords1 + flow_init
461
+
462
+ disparity_predictions = []
452
463
for _ in range (num_iters ):
453
464
coords1 = coords1 .detach () # Don't backpropagate gradients through this branch, see paper
454
465
corr_features = self .corr_block (centroids_coords = coords1 , corr_pyramid = corr_pyramid )
455
466
456
- depth = coords1 - coords0
467
+ disparity = coords1 - coords0
457
468
if self .slow_fast :
458
469
# Using slow_fast GRU (see paper section 3.4). The lower resolution are processed more often
459
470
for i in range (1 , self .num_level ):
460
471
# We only processed the smallest i levels
461
472
level_processed = [False ] * (self .num_level - i ) + [True ] * i
462
473
hidden_states = self .update_block (
463
- hidden_states , contexts , corr_features , depth , level_processed = level_processed
474
+ hidden_states , contexts , corr_features , disparity , level_processed = level_processed
464
475
)
465
476
hidden_states = self .update_block (
466
- hidden_states , contexts , corr_features , depth , level_processed = [True ] * self .num_level
477
+ hidden_states , contexts , corr_features , disparity , level_processed = [True ] * self .num_level
467
478
)
468
- # Take the largest hidden_state to get the depth
479
+ # Take the largest hidden_state to get the disparity
469
480
hidden_state = hidden_states [0 ]
470
- delta_depth = self .depth_head (hidden_state )
471
- # in stereo mode, project depth onto epipolar
472
- delta_depth [:, 1 ] = 0.0
481
+ delta_disparity = self .disparity_head (hidden_state )
482
+ # in stereo mode, project disparity onto epipolar
483
+ delta_disparity [:, 1 ] = 0.0
473
484
474
- coords1 = coords1 + delta_depth
485
+ coords1 = coords1 + delta_disparity
475
486
up_mask = None if self .mask_predictor is None else self .mask_predictor (hidden_state )
476
- upsampled_depth = upsample_flow ((coords1 - coords0 ), up_mask = up_mask , factor = self .base_downsampling_ratio )
477
- depth_predictions .append (upsampled_depth [:, :1 ])
487
+ upsampled_disparity = upsample_flow (
488
+ (coords1 - coords0 ), up_mask = up_mask , factor = self .base_downsampling_ratio
489
+ )
490
+ disparity_predictions .append (upsampled_disparity [:, :1 ])
478
491
479
- return depth_predictions
492
+ return disparity_predictions
480
493
481
494
482
495
def _raft_stereo (
@@ -576,8 +589,8 @@ def _raft_stereo(
576
589
motion_encoder = motion_encoder , hidden_dims = update_block_hidden_dims
577
590
)
578
591
579
- # We use the largest scale hidden_dims of update_block to get the predicted depth
580
- depth_head = kwargs .pop ("depth_head " , None ) or FlowHead (
592
+ # We use the largest scale hidden_dims of update_block to get the predicted disparity
593
+ disparity_head = kwargs .pop ("disparity_head " , None ) or FlowHead (
581
594
in_channels = update_block_hidden_dims [0 ],
582
595
hidden_size = flow_head_hidden_size ,
583
596
)
@@ -598,7 +611,7 @@ def _raft_stereo(
598
611
corr_pyramid = corr_pyramid ,
599
612
corr_block = corr_block ,
600
613
update_block = update_block ,
601
- depth_head = depth_head ,
614
+ disparity_head = disparity_head ,
602
615
mask_predictor = mask_predictor ,
603
616
slow_fast = slow_fast ,
604
617
** kwargs , # not really needed, all params should be consumed by now
0 commit comments