Skip to content

Update RAFT Stereo to be more in sync with CREStereo implementation #6575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 14, 2022
Merged
83 changes: 48 additions & 35 deletions torchvision/prototype/models/depth/stereo/raft_stereo.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def forward(
hidden_states: List[Tensor],
contexts: List[List[Tensor]],
corr_features: Tensor,
depth: Tensor,
disparity: Tensor,
level_processed: List[bool],
) -> List[Tensor]:
# We call it reverse_i because it has a reversed ordering compared to hidden_states
Expand All @@ -215,7 +215,7 @@ def forward(
# X is the concatenation of the 2x downsampled hidden_dim (or motion_features if no bigger dim) with
# the upsampled hidden_dim (or nothing if it does not exist).
if i == 0:
features = self.motion_encoder(depth, corr_features)
features = self.motion_encoder(disparity, corr_features)
else:
# 2x downsampled features from larger hidden states
features = F.avg_pool2d(hidden_states[i - 1], kernel_size=3, stride=2, padding=1)
Expand All @@ -235,14 +235,14 @@ def forward(

hidden_states[i] = gru(hidden_states[i], features, contexts[i])

# NOTE: For slow-fast gru, we dont always want to calculate delta depth for every call on UpdateBlock
# Hence we move the delta depth calculation to the RAFT-Stereo main forward
# NOTE: For slow-fast GRU, we don't always want to calculate the delta disparity on every call to UpdateBlock.
# Hence we move the delta disparity calculation to the RAFT-Stereo main forward.

return hidden_states


class MaskPredictor(raft.MaskPredictor):
"""Mask predictor to be used when upsampling the predicted depth."""
"""Mask predictor to be used when upsampling the predicted disparity."""

# We add out_channels compared to raft.MaskPredictor
def __init__(self, *, in_channels: int, hidden_size: int, out_channels: int, multiplier: float = 0.25):
Expand Down Expand Up @@ -346,16 +346,16 @@ def __init__(
corr_pyramid: CorrPyramid1d,
corr_block: CorrBlock1d,
update_block: MultiLevelUpdateBlock,
depth_head: nn.Module,
disparity_head: nn.Module,
mask_predictor: Optional[nn.Module] = None,
slow_fast: bool = False,
):
"""RAFT-Stereo model from
`RAFT-Stereo: Multilevel Recurrent Field Transforms for Stereo Matching <https://arxiv.org/abs/2109.07547>`_.

args:
feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``image1`` and ``image2``.
context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``image1``.
feature_encoder (FeatureEncoder): The feature encoder. Its input is the concatenation of ``left_image`` and ``right_image``.
context_encoder (MultiLevelContextEncoder): The context encoder. Its input is ``left_image``.
It has multi-level output and each level will have 2 parts:

- one part will be used as the actual "context", passed to the recurrent unit of the ``update_block``
Expand All @@ -370,8 +370,8 @@ def __init__(

update_block (MultiLevelUpdateBlock): The update block, which contains the motion encoder, and the recurrent unit.
It takes as input the hidden state of its recurrent unit, the context, the correlation
features, and the current predicted depth. It outputs an updated hidden state
depth_head (nn.Module): The depth head block will convert from the hidden state into changes in depth.
features, and the current predicted disparity. It outputs an updated hidden state
disparity_head (nn.Module): The disparity head block will convert from the hidden state into changes in disparity.
mask_predictor (nn.Module, optional): Predicts the mask that will be used to upsample the predicted flow.
If ``None`` (default), the flow is upsampled using interpolation.
slow_fast (bool): A boolean that specifies whether we should use slow-fast GRU or not. See RAFT-Stereo paper
Expand All @@ -380,6 +380,10 @@ def __init__(
super().__init__()
_log_api_usage_once(self)

# This indicates that the disparity output will have only 1 channel (representing the horizontal axis).
# We need this because some stereo matching models, like CREStereo, might have 2 channels in the output.
self.output_channels = 1

self.feature_encoder = feature_encoder
self.context_encoder = context_encoder

Expand All @@ -388,7 +392,7 @@ def __init__(
self.corr_pyramid = corr_pyramid
self.corr_block = corr_block
self.update_block = update_block
self.depth_head = depth_head
self.disparity_head = disparity_head
self.mask_predictor = mask_predictor

hidden_dims = self.update_block.hidden_dims
Expand All @@ -399,26 +403,29 @@ def __init__(
)
self.slow_fast = slow_fast

def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[Tensor]:
def forward(
self, left_image: Tensor, right_image: Tensor, flow_init: Optional[Tensor] = None, num_iters: int = 12
) -> List[Tensor]:
"""
Return dept predictions on every iterations as a list of Tensor.
Return the disparity predictions of every iteration as a list of Tensors.
args:
image1 (Tensor): The input left image with layout B, C, H, W
image2 (Tensor): The input right image with layout B, C, H, W
left_image (Tensor): The input left image with layout B, C, H, W
right_image (Tensor): The input right image with layout B, C, H, W
flow_init (Optional[Tensor]): Initial estimate for the disparity. Default: None
num_iters (int): Number of update block iterations on the largest resolution. Default: 12
"""
batch_size, _, h, w = image1.shape
batch_size, _, h, w = left_image.shape
torch._assert(
(h, w) == image2.shape[-2:],
f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}",
(h, w) == right_image.shape[-2:],
f"input images should have the same shape, instead got ({h}, {w}) != {right_image.shape[-2:]}",
)

torch._assert(
(h % self.base_downsampling_ratio == 0 and w % self.base_downsampling_ratio == 0),
f"input image H and W should be divisible by {self.base_downsampling_ratio}, insted got H={h} and W={w}",
)

fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
fmaps = self.feature_encoder(torch.cat([left_image, right_image], dim=0))
fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
torch._assert(
fmap1.shape[-2:] == (h // self.base_downsampling_ratio, w // self.base_downsampling_ratio),
Expand All @@ -428,7 +435,7 @@ def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[T
corr_pyramid = self.corr_pyramid(fmap1, fmap2)

# Multi level contexts
context_outs = self.context_encoder(image1)
context_outs = self.context_encoder(left_image)

hidden_dims = self.update_block.hidden_dims
context_out_channels = [context_outs[i].shape[1] - hidden_dims[i] for i in range(len(context_outs))]
Expand All @@ -448,35 +455,41 @@ def forward(self, image1: Tensor, image2: Tensor, num_iters: int = 12) -> List[T
coords0 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)
coords1 = make_coords_grid(batch_size, Hf, Wf).to(fmap1.device)

depth_predictions = []
# We use flow_init for cascade inference
if flow_init is not None:
coords1 = coords1 + flow_init
Comment on lines +459 to +460
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only change for which I don't have context, but I understand it's a new feature.

Copy link
Contributor

@TeodorPoncu TeodorPoncu Sep 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to run inference at a resolution significantly larger than the one at which the model was trained, you can perform cascaded inference.

Cascaded inference first computes the flow for a downsampled version of the image, and uses that flow as a prior for the full resolution image.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to give more context:
Here is the code for the cascaded inference: https://github.com/pytorch/vision/blob/test-crestereo-training/references/stereo_matching/evaluation.py#L57

And here is a reference to the original RAFT-Stereo implementation of this part: https://github.com/princeton-vl/RAFT-Stereo/blob/main/core/raft_stereo.py#L104


disparity_predictions = []
for _ in range(num_iters):
coords1 = coords1.detach() # Don't backpropagate gradients through this branch, see paper
corr_features = self.corr_block(centroids_coords=coords1, corr_pyramid=corr_pyramid)

depth = coords1 - coords0
disparity = coords1 - coords0
if self.slow_fast:
# Using slow_fast GRU (see paper section 3.4). The lower resolutions are processed more often
for i in range(1, self.num_level):
# We only process the smallest i levels
level_processed = [False] * (self.num_level - i) + [True] * i
hidden_states = self.update_block(
hidden_states, contexts, corr_features, depth, level_processed=level_processed
hidden_states, contexts, corr_features, disparity, level_processed=level_processed
)
hidden_states = self.update_block(
hidden_states, contexts, corr_features, depth, level_processed=[True] * self.num_level
hidden_states, contexts, corr_features, disparity, level_processed=[True] * self.num_level
)
# Take the largest hidden_state to get the depth
# Take the largest hidden_state to get the disparity
hidden_state = hidden_states[0]
delta_depth = self.depth_head(hidden_state)
# in stereo mode, project depth onto epipolar
delta_depth[:, 1] = 0.0
delta_disparity = self.disparity_head(hidden_state)
# in stereo mode, project disparity onto epipolar
delta_disparity[:, 1] = 0.0

coords1 = coords1 + delta_depth
coords1 = coords1 + delta_disparity
up_mask = None if self.mask_predictor is None else self.mask_predictor(hidden_state)
upsampled_depth = upsample_flow((coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio)
depth_predictions.append(upsampled_depth[:, :1])
upsampled_disparity = upsample_flow(
(coords1 - coords0), up_mask=up_mask, factor=self.base_downsampling_ratio
)
disparity_predictions.append(upsampled_disparity[:, :1])

return depth_predictions
return disparity_predictions


def _raft_stereo(
Expand Down Expand Up @@ -576,8 +589,8 @@ def _raft_stereo(
motion_encoder=motion_encoder, hidden_dims=update_block_hidden_dims
)

# We use the largest scale hidden_dims of update_block to get the predicted depth
depth_head = kwargs.pop("depth_head", None) or FlowHead(
# We use the largest scale hidden_dims of update_block to get the predicted disparity
disparity_head = kwargs.pop("disparity_head", None) or FlowHead(
in_channels=update_block_hidden_dims[0],
hidden_size=flow_head_hidden_size,
)
Expand All @@ -598,7 +611,7 @@ def _raft_stereo(
corr_pyramid=corr_pyramid,
corr_block=corr_block,
update_block=update_block,
depth_head=depth_head,
disparity_head=disparity_head,
mask_predictor=mask_predictor,
slow_fast=slow_fast,
**kwargs, # not really needed, all params should be consumed by now
Expand Down