3D NMS and ROI Align #2337

Open · wants to merge 1 commit into main
Conversation

@mibaumgartner commented Jun 22, 2020

3D data is gaining more and more popularity within the deep learning community. It would therefore be great to have unified 3D NMS and 3D ROI Align operations for current and future projects such as MONAI. This PR aims to implement those features :)

As a first step, @Gregor1337 added the kernels from the medicaldetectiontoolkit (torch1.x branch). Some things are still missing and need additional discussion; rough sketches of both operations are shown below.
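For readers unfamiliar with the 3D variant, here is a minimal pure-PyTorch sketch of greedy 3D NMS. This is illustrative only, not the CUDA kernel from this PR; the name `nms_3d` and the (x1, y1, z1, x2, y2, z2) corner format are assumptions.

```python
# Minimal sketch of greedy 3D NMS (illustrative, not the PR's CUDA kernel).
# Assumes axis-aligned boxes in (x1, y1, z1, x2, y2, z2) corner format.
import torch

def nms_3d(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
    """boxes: (N, 6), scores: (N,); returns indices of kept boxes."""
    order = scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        i = order[0]
        keep.append(i.item())
        if order.numel() == 1:
            break
        rest = order[1:]
        # Intersection volume between the top-scoring box and the rest.
        lt = torch.max(boxes[i, :3], boxes[rest, :3])  # max of (x1, y1, z1)
        rb = torch.min(boxes[i, 3:], boxes[rest, 3:])  # min of (x2, y2, z2)
        inter = (rb - lt).clamp(min=0).prod(dim=1)
        vol_i = (boxes[i, 3:] - boxes[i, :3]).prod()
        vol_rest = (boxes[rest, 3:] - boxes[rest, :3]).prod(dim=1)
        iou = inter / (vol_i + vol_rest - inter)
        # Drop every remaining box that overlaps the kept one too much.
        order = rest[iou <= iou_threshold]
    return torch.tensor(keep, dtype=torch.long)
```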

Open questions:

  • Should this be added here or is there a better place?
  • Should the 2D and 3D kernels be integrated into one function, or should there be two functions (one for 2D and one for 3D)?
  • Currently, only the CUDA kernels are implemented; should there also be a CPU implementation?

TODOs:

  • Update to the current torchvision master
  • Unit tests for the 3D use case

This should also address #1678.
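For the second half of the title, here is a simplified sketch of 3D RoI Align built on torch.nn.functional.grid_sample, taking a single trilinear sample per output bin instead of the multi-sample averaging a real kernel would do; `roi_align_3d` and the (batch_idx, x1, y1, z1, x2, y2, z2) RoI format are assumptions for illustration.

```python
# Simplified 3D RoI Align sketch (one trilinear sample per output bin;
# a real kernel averages several samples per bin). Illustrative only.
import torch
import torch.nn.functional as F

def roi_align_3d(feats, rois, out_size=(4, 4, 4), spatial_scale=1.0):
    """feats: (N, C, D, H, W); rois: (K, 7) as (batch_idx, x1, y1, z1, x2, y2, z2)."""
    N, C, D, H, W = feats.shape
    od, oh, ow = out_size
    outputs = []
    for roi in rois:
        b = int(roi[0])
        x1, y1, z1, x2, y2, z2 = (roi[1:] * spatial_scale).tolist()
        # Sample points along each axis of the RoI.
        zs = torch.linspace(z1, z2, od)
        ys = torch.linspace(y1, y2, oh)
        xs = torch.linspace(x1, x2, ow)
        gz, gy, gx = torch.meshgrid(zs, ys, xs, indexing="ij")
        # grid_sample expects (x, y, z) coordinates normalized to [-1, 1].
        grid = torch.stack([
            gx / (W - 1) * 2 - 1,
            gy / (H - 1) * 2 - 1,
            gz / (D - 1) * 2 - 1,
        ], dim=-1).unsqueeze(0)  # (1, od, oh, ow, 3)
        outputs.append(F.grid_sample(feats[b:b + 1], grid, align_corners=True))
    return torch.cat(outputs, dim=0)  # (K, C, od, oh, ow)
```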

@mjorgecardoso @pfjaeger

@mibaumgartner changed the title from "add nms 2d3d and roialign 3d, needs testing and adapting" to "3D NMS and ROI Align" on Jun 22, 2020
@fmassa (Member) commented Jul 7, 2020

Hi,

Thanks for opening this PR and sorry for the delay in replying.

Before we move forward, it would be good to discuss what exactly we would want to support.
Images and volumes have their differences, and trying to support both can become confusing.
For example, starting with torchvision 0.4 we added support for video, which can be seen as a 4D data type like volumes, potentially leading to confusion.

Proposal: discuss on an issue first

I think it would be very valuable to open an issue in torchvision first so that we can discuss the potential scope of supporting volumes / medical data types. This is something I have considered in the past but never had the chance to think about more deeply, so maybe now is a good time :-)

If you could describe what types of operations would generally be needed, what the differences are with respect to images / videos, and how we could handle those differences, that would be great!

Also, if you could share some references to highly influential research papers that use those 3D operations, that would be great.

My thinking is that we should have a clear scope of what we could potentially do before starting any work, so that we have a clear picture ahead of us.

About your open questions

Should this be added here or is there a better place?
Should the 2D and 3D kernels be integrated into one function, or should there be two functions (one for 2D and one for 3D)?

Those are great questions and should be discussed in an issue, so that we can have a full picture of everything.

Currently, only the CUDA kernels are implemented; should there also be a CPU implementation?

In torchvision we do require both CPU and CUDA implementations for our operators, so that they are accessible to everyone (not everyone has a GPU available).

@mjorgecardoso commented
Hi Francisco,

Medical imaging is a huge field of research, with conferences such as ISMRM (5k+ attendees), MICCAI (2.5k+), and ISBI (1.5k+). Volumetric neural network operations (convolutions, pooling, etc.) are common and already supported in PyTorch (see https://pytorch.org/docs/master/generated/torch.nn.Conv3d.html).

Standard practice within PyTorch is that tensors are shaped as [N,C,D,H,W], where D, H, W are depth, height, and width. Here, we are only requesting spatial 3D support. This can be encoded in the name of the function call, as per the PyTorch standard.

In standard high-dimensional systems (e.g. medical imaging data can often be 5D), dimensions are often described as spatial [W,H,D], time [T], and features [C], concatenated in this order: [W,H,D,T,C]. Under this construct, 2D video data is [W,H,1,T,3], with 3 being the RGB features of the video. Volumetric single-time-point multimodal imaging data is [W,H,D,1,C], where C is the number of channels (MRI, CT, etc.). PyTorch normally flips this around, meaning we end up with a [C,T,D,H,W] dimensionality.

If you squeeze out the singleton dimension above, you end up with a [C,T,H,W] tensor for 2D video and [C,D,H,W] for 3D volumetric data, with the only difference being the meaning of a "convolution" or any other operator on this domain.
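As a concrete illustration of these layouts (the shapes below are chosen arbitrarily, not taken from the thread):

```python
# Concrete PyTorch shapes for the two readings discussed above.
import torch

video  = torch.randn(2, 3, 16, 224, 224)  # [N, C, T, H, W]: 2 RGB clips of 16 frames
volume = torch.randn(2, 1, 64, 224, 224)  # [N, C, D, H, W]: 2 single-channel CT volumes

# Both are 5D; after removing the batch dimension they are [C, T, H, W] vs
# [C, D, H, W], and only the interpretation of the second axis differs.
conv = torch.nn.Conv3d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
print(conv(volume).shape)  # torch.Size([2, 8, 64, 224, 224])
```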

@fmassa (Member) commented Jul 7, 2020

Hi @mjorgecardoso

Thanks for your reply.
Can we continue the discussion in an issue, where you summarize what you have mentioned and what I'm about to discuss? I think it would be useful to get other people's feedback there, and an issue might be more discoverable.

So, if we keep the squeezed-out formulation that you proposed:

If you squeeze out the singleton dimension above, you end up with a [C,T,H,W] tensor for 2D video and [C,D,H,W] for 3D volumetric data, with the only difference being the meaning of a "convolution" or any other operator on this domain.

This sounds reasonable to me, and is what we currently do for video models in torchvision. The problem is that the behavior of some operations (like roi_align and nms) is IMO dependent on the domain: if we are doing it for video, then we don't want to apply NMS over frames, and similarly for roi_align we just consider the temporal dimension as another batch dimension.
But for volumetric data, we want to compute nms and roi_align across the depth dimension as well, which can potentially create confusion.
I think the same applies to other more mainstream operations, like resize (or interpolate) -- for video we generally don't want to interpolate over the temporal dimension.
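As a concrete sketch of that convention problem (my illustration, assuming an [N, C, T/D, H, W] layout), the same 5D tensor can be resized two different ways depending on whether the third axis means time or depth:

```python
# The same 5D tensor, resized under the two conventions.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 16, 128, 128)  # [N, C, T or D, H, W]

# Volumetric reading: interpolate across depth as well (trilinear).
vol = F.interpolate(x, size=(32, 256, 256), mode="trilinear", align_corners=False)

# Video reading: fold time into the batch and resize frames only (bilinear).
n, c, t, h, w = x.shape
frames = x.permute(0, 2, 1, 3, 4).reshape(n * t, c, h, w)
frames = F.interpolate(frames, size=(256, 256), mode="bilinear", align_corners=False)
vid = frames.reshape(n, t, c, 256, 256).permute(0, 2, 1, 3, 4)

print(vol.shape)  # torch.Size([1, 3, 32, 256, 256]) -- depth was interpolated
print(vid.shape)  # torch.Size([1, 3, 16, 256, 256]) -- temporal dim untouched
```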

Now, I think this is something that could be fixed with the right conventions and proper documentation, but I would love your thoughts on how to avoid confusion here.

@mibaumgartner (Author) commented Jul 7, 2020

Thank you for your quick responses @fmassa and @mjorgecardoso. I wasn't quite sure if I should hijack the old issue for our discussion, so I opened a new one, #2402, where I summarised the main points of our discussion from here.

yan12125 added a commit to yan12125/pytorch_roialign_nms that referenced this pull request Aug 12, 2020