From ff976a6fdff19e865865029f18668c7451b7ad18 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Tue, 13 Sep 2022 14:24:15 +0100 Subject: [PATCH 1/4] Update raft_stereo to sync forward param with CREStereo and have output_channel --- .../models/depth/stereo/raft_stereo.py | 30 ++++++++++++------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index a1c0a7bcc8d..ffa5db22ff3 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -354,8 +354,8 @@ def __init__( `RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching `_. args: - feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``image1`` and ``image2``. - context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``image1``. + feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``left_image`` and ``right_image``. + context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``left_image``. It has multi-level output and each level will have 2 parts: - one part will be used as the actual "context", passed to the recurrent unit of the ``update_block`` @@ -380,6 +380,10 @@ def __init__( super().__init__() _log_api_usage_once(self) + # This indicate that the disparity output will be only have 1 channel (represent horizontal axis). 
+ # We need this because some stereo matching model like CREStereo might have 2 channel on the output + self.output_channel = 1 + self.feature_encoder = feature_encoder self.context_encoder = context_encoder @@ -399,18 +403,20 @@ def __init__( ) self.slow_fast = slow_fast - def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[Tensor]: + def forward( + self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 12 + ) -> List[Tensor]: """ Return dept predictions on every iterations as a list of Tensor. args: - image1 (Tensor): The input left image with layout B, C, H, W - image2 (Tensor): The input right image with layout B, C, H, W + left_image (Tensor): The input left image with layout B, C, H, W + right_image (Tensor): The input right image with layout B, C, H, W num_iters (int): Number of update block iteration on the largest resolution. Default: 12 """ - batch_size, _, h, w = image1.shape + batch_size, _, h, w = left_image.shape torch._assert( - (h, w) == image2.shape[-2:], - f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}", + (h, w) == right_image.shape[-2:], + f"input images should have the same shape, instead got ({h}, {w}) != {right_image.shape[-2:]}", ) torch._assert( @@ -418,7 +424,7 @@ def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[T f"input image H and W should be divisible by {self.base_downsampling_ratio}, insted got H={h} and W={w}", ) - fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0)) + fmaps = self.feature_encoder(torch.cat([left_image, right_image], dim=0)) fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0) torch._assert( fmap1.shape[-2:] == (h // self.base_downsampling_ratio, w // self.base_downsampling_ratio), @@ -428,7 +434,7 @@ def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[T corr_pyramid = self.corr_pyramid(fmap1, fmap2) # Multi level contexts - context_outs 
= self.context_encoder(image1) + context_outs = self.context_encoder(left_image) hidden_dims = self.update_block.hidden_dims context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))] @@ -448,6 +454,10 @@ def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[T coords0 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device) coords1 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device) + # We use flow_init for cascade inference + if flow_init is not None: + coords1 = coords1 + flow_init + depth_predictions = [] for _ in range(num_iters): coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper From 75cb8af1ef95a95b0ccb98a6fd884f6a02dfcb87 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Tue, 13 Sep 2022 15:12:51 +0100 Subject: [PATCH 2/4] Add flow_init param to docstring --- torchvision/prototype/models/depth/stereo/raft_stereo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index ffa5db22ff3..a54b479e381 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -411,6 +411,7 @@ def forward( args: left_image (Tensor): The input left image with layout B, C, H, W right_image (Tensor): The input right image with layout B, C, H, W + flow_init (Optional[Tensor]): Initial estimate for the disparity. Default: None num_iters (int): Number of update block iteration on the largest resolution. 
Default: 12 """ batch_size, _, h, w = left_image.shape From d3972dd819f2df10047b0b57b73e091d0ba83211 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Tue, 13 Sep 2022 15:19:35 +0100 Subject: [PATCH 3/4] Use output_channels instead of output_channel --- torchvision/prototype/models/depth/stereo/raft_stereo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index a54b479e381..cb43cdbba86 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -382,7 +382,7 @@ def __init__( # This indicate that the disparity output will be only have 1 channel (represent horizontal axis). # We need this because some stereo matching model like CREStereo might have 2 channel on the output - self.output_channel = 1 + self.output_channels = 1 self.feature_encoder = feature_encoder self.context_encoder = context_encoder From 36b5ed4f5d37f9887b516b741d4620defe492e00 Mon Sep 17 00:00:00 2001 From: Yosua Michael Maranatha Date: Tue, 13 Sep 2022 18:16:20 +0100 Subject: [PATCH 4/4] Replace depth with disparity since what we predict is actually disparity instead of actual depth --- .../models/depth/stereo/raft_stereo.py | 52 ++++++++++--------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/torchvision/prototype/models/depth/stereo/raft_stereo.py b/torchvision/prototype/models/depth/stereo/raft_stereo.py index cb43cdbba86..541a11f0434 100644 --- a/torchvision/prototype/models/depth/stereo/raft_stereo.py +++ b/torchvision/prototype/models/depth/stereo/raft_stereo.py @@ -204,7 +204,7 @@ def forward( hidden_states: List[Tensor], contexts: List[List[Tensor]], corr_features: Tensor, - depth: Tensor, + disparity: Tensor, level_processed: List[bool], ) -> List[Tensor]: # We call it reverse_i because it has a reversed ordering compared to hidden_states @@ -215,7 +215,7 
@@ def forward( # X is concatination of 2x downsampled hidden_dim (or motion_features if no bigger dim) with # upsampled hidden_dim (or nothing if not exist). if i == 0: - features = self.motion_encoder(depth, corr_features) + features = self.motion_encoder(disparity, corr_features) else: # 2x downsampled features from larger hidden states features = F.avg_pool2d(hidden_states[i - 1], kernel_size=3, stride=2, padding=1) @@ -235,14 +235,14 @@ def forward( hidden_states[i] = gru(hidden_states[i], features, contexts[i]) - # NOTE: For slow-fast gru, we dont always want to calculate delta depth for every call on UpdateBlock - # Hence we move the delta depth calculation to the RAFT-Stereo main forward + # NOTE: For slow-fast gru, we dont always want to calculate delta disparity for every call on UpdateBlock + # Hence we move the delta disparity calculation to the RAFT-Stereo main forward return hidden_states class MaskPredictor(raft.MaskPredictor): - """Mask predictor to be used when upsampling the predicted depth.""" + """Mask predictor to be used when upsampling the predicted disparity.""" # We add out_channels compared to raft.MaskPredictor def __init__(self, *, in_channels: int, hidden_size: int, out_channels: int, multiplier: float = 0.25): @@ -346,7 +346,7 @@ def __init__( corr_pyramid: CorrPyramid1d, corr_block: CorrBlock1d, update_block: MultiLevelUpdateBlock, - depth_head: nn.Module, + disparity_head: nn.Module, mask_predictor: Optional[nn.Module] = None, slow_fast: bool = False, ): @@ -370,8 +370,8 @@ def __init__( update_block (MultiLevelUpdateBlock): The update block, which contains the motion encoder, and the recurrent unit. It takes as input the hidden state of its recurrent unit, the context, the correlation - features, and the current predicted depth. It outputs an updated hidden state - depth_head (nn.Module): The depth head block will convert from the hidden state into changes in depth. + features, and the current predicted disparity. 
It outputs an updated hidden state + disparity_head (nn.Module): The disparity head block will convert from the hidden state into changes in disparity. mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow. If ``None`` (default), the flow is upsampled using interpolation. slow_fast (bool): A boolean that specify whether we should use slow-fast GRU or not. See RAFT-Stereo paper @@ -392,7 +392,7 @@ def __init__( self.corr_pyramid = corr_pyramid self.corr_block = corr_block self.update_block = update_block - self.depth_head = depth_head + self.disparity_head = disparity_head self.mask_predictor = mask_predictor hidden_dims = self.update_block.hidden_dims @@ -407,7 +407,7 @@ def forward( self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 12 ) -> List[Tensor]: """ - Return dept predictions on every iterations as a list of Tensor. + Return disparity predictions on every iterations as a list of Tensor. args: left_image (Tensor): The input left image with layout B, C, H, W right_image (Tensor): The input right image with layout B, C, H, W @@ -459,35 +459,37 @@ def forward( if flow_init is not None: coords1 = coords1 + flow_init - depth_predictions = [] + disparity_predictions = [] for _ in range(num_iters): coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper corr_features = self.corr_block(centroids_coords=coords1, corr_pyramid=corr_pyramid) - depth = coords1 - coords0 + disparity = coords1 - coords0 if self.slow_fast: # Using slow_fast GRU (see paper section 3.4). 
The lower resolution are processed more often for i in range(1, self.num_level): # We only processed the smallest i levels level_processed = [False] * (self.num_level - i) + [True] * i hidden_states = self.update_block( - hidden_states, contexts, corr_features, depth, level_processed=level_processed + hidden_states, contexts, corr_features, disparity, level_processed=level_processed ) hidden_states = self.update_block( - hidden_states, contexts, corr_features, depth, level_processed=[True] * self.num_level + hidden_states, contexts, corr_features, disparity, level_processed=[True] * self.num_level ) - # Take the largest hidden_state to get the depth + # Take the largest hidden_state to get the disparity hidden_state = hidden_states[0] - delta_depth = self.depth_head(hidden_state) - # in stereo mode, project depth onto epipolar - delta_depth[:, 1] = 0.0 + delta_disparity = self.disparity_head(hidden_state) + # in stereo mode, project disparity onto epipolar + delta_disparity[:, 1] = 0.0 - coords1 = coords1 + delta_depth + coords1 = coords1 + delta_disparity up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state) - upsampled_depth = upsample_flow((coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio) - depth_predictions.append(upsampled_depth[:, :1]) + upsampled_disparity = upsample_flow( + (coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio + ) + disparity_predictions.append(upsampled_disparity[:, :1]) - return depth_predictions + return disparity_predictions def _raft_stereo( @@ -587,8 +589,8 @@ def _raft_stereo( motion_encoder=motion_encoder, hidden_dims=update_block_hidden_dims ) - # We use the largest scale hidden_dims of update_block to get the predicted depth - depth_head = kwargs.pop("depth_head", None) or FlowHead( + # We use the largest scale hidden_dims of update_block to get the predicted disparity + disparity_head = kwargs.pop("disparity_head", None) or FlowHead( 
in_channels=update_block_hidden_dims[0], hidden_size=flow_head_hidden_size, ) @@ -609,7 +611,7 @@ def _raft_stereo( corr_pyramid=corr_pyramid, corr_block=corr_block, update_block=update_block, - depth_head=depth_head, + disparity_head=disparity_head, mask_predictor=mask_predictor, slow_fast=slow_fast, **kwargs, # not really needed, all params should be consumed by now