scipy.ndimage.find_objects #102201


Open

kazemSafari opened this issue May 24, 2023 · 11 comments
Assignees
Labels
feature: A request for a proper, new feature.
module: nn: Related to torch.nn
needs research: We need to decide whether or not this merits inclusion, based on research world
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@kazemSafari

kazemSafari commented May 24, 2023

🚀 The feature, motivation and pitch

This function is a basic building block of any biomedical image analysis application.
It returns a list of tuples of slices giving the coordinates of labelled objects/cells within a mask image of dtype uint16 or uint8, assuming the image background is 0 and the labelled objects are numbered 1, 2, ..., max_label.

I was wondering whether it is possible to implement it in torch C++ using a simple TensorIterator.

In the simplest case, it would take a 2D tensor of size (H, W) as input and output a tensor of slices of size (N, 2), where N is the number of objects and each row is [slice(start, end, step), slice(start, end, step)].
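
For reference, here is what the SciPy version returns on a toy mask (a small runnable example; SciPy always emits None for the step):

```python
import numpy as np
from scipy import ndimage

# Toy labelled mask: background is 0, objects are labelled 1 and 2.
mask = np.array([
    [0, 1, 1, 0],
    [0, 1, 0, 0],
    [0, 0, 0, 2],
    [0, 0, 2, 2],
], dtype=np.uint8)

print(ndimage.find_objects(mask))
# [(slice(0, 2, None), slice(1, 3, None)),    bounding box of label 1
#  (slice(2, 4, None), slice(2, 4, None))]    bounding box of label 2
```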

Alternatives

The SciPy C implementation can be found here:
https://github.com/scipy/scipy/blob/v1.10.1/scipy/ndimage/src/ni_measure.c

which uses iterators defined here:
https://github.com/scipy/scipy/blob/v1.10.1/scipy/ndimage/src/ni_support.h

Additional context

Could it also be extended to allow extracting objects from a tensor of shape (B, C, W, H), where B is the batch size, C the number of channels, W the width, and H the height?

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

@soulitzer soulitzer added the feature, module: nn, and triaged labels May 25, 2023
@github-project-automation github-project-automation bot moved this to To pick up in torch.nn/optim May 26, 2023
@jbschlosser jbschlosser added the needs research label May 26, 2023
@albanD
Collaborator

albanD commented Jun 5, 2023

cc @rgommers in case you have ideas!

@kazemSafari
Author

kazemSafari commented Jun 5, 2023

@soulitzer @albanD and @jbschlosser, thanks for the follow-up and for showing interest.
In summary, the base C code uses a NumPy iterator object to walk over every entry of the array and record, for each object, the smallest and largest index along each dimension in an array pointer called regions (the number of objects is known in advance); the regions pointer is updated with multi-threading in NumPy.
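
For what it's worth, the same min/max bookkeeping can be vectorized in current PyTorch with scatter_reduce_. This is only a sketch under the assumptions stated in the issue (a single 2D mask, labels 1..max_label with no gaps), not a substitute for a proper C++ kernel:

```python
import torch

def find_objects_torch(mask: torch.Tensor):
    # Coordinates of all labelled (nonzero) pixels.
    ys, xs = torch.nonzero(mask, as_tuple=True)
    lab = mask[ys, xs].long() - 1                 # 0-based label per pixel
    n = int(mask.max())
    init = mask.shape[0] + mask.shape[1]          # larger than any coordinate
    ymin = torch.full((n,), init, dtype=torch.long, device=mask.device)
    xmin = torch.full((n,), init, dtype=torch.long, device=mask.device)
    ymax = torch.full((n,), -1, dtype=torch.long, device=mask.device)
    xmax = torch.full((n,), -1, dtype=torch.long, device=mask.device)
    ymin.scatter_reduce_(0, lab, ys, reduce="amin")   # smallest row per label
    xmin.scatter_reduce_(0, lab, xs, reduce="amin")   # smallest column per label
    ymax.scatter_reduce_(0, lab, ys, reduce="amax")   # largest row per label
    xmax.scatter_reduce_(0, lab, xs, reduce="amax")   # largest column per label
    return [(slice(int(ymin[i]), int(ymax[i]) + 1),
             slice(int(xmin[i]), int(xmax[i]) + 1)) for i in range(n)]
```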

To add more context, there is usually a labelled mask_image and a corresponding intensity_image, say of the nucleus/cytoplasm/nucleoli/actin/mitochondria channel of a fluorescent microscopy screen, as just one of many potential applications.

So I was also wondering whether you could extend this function into a class with another method, called extract_objects, that extracts the objects from the corresponding image using the bounding boxes that find_objects found in the mask_image.

Different types of measurements from those objects, collectively called morphological profiling, are needed: mean_intensity, area, texture, etc. These measurements are used in the pharmaceutical/biomedical imaging industry to figure out which drug could treat a condition, say breast cancer.

This class, together with the NestedTensor API (the dimensions of the cells, 300-1000 of them within a single image, are not the same, and zero-padding them to a common size would require a huge amount of RAM), can help build a Dataset + Model API for this type of application.
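
A minimal CPU-side sketch of what such an extract_objects method could look like, assuming a 2D NumPy label mask and a matching 2D torch intensity image (the name and signature here are hypothetical):

```python
import torch
from scipy import ndimage

def extract_objects(mask_np, intensity_image):
    # Bounding boxes from the label mask, then variable-sized crops from the
    # intensity image; a nested tensor holds them without zero-padding.
    slices = ndimage.find_objects(mask_np)
    crops = [intensity_image[s] for s in slices if s is not None]
    return torch.nested.nested_tensor(crops)
```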

@soulitzer
Contributor

Thanks for the proposal. I was wondering how this fits into an image segmentation workflow in general; specifically, why not have whatever image segmentation model you are running just generate the masks separately - isn't that how it's usually done?

@kazemSafari
Author

kazemSafari commented Jun 5, 2023

@soulitzer, thanks for your reply.
Please note that in this context, a single biological experiment generates at least ~17000 images at 20X magnification per assay plate, and we have hundreds of these plates. For example, a 384-well plate is a (16, 24) grid where each well, which is treated with a compound, is imaged at 9 different fields of view from its center in 5 channels: 16 * 24 * 9 * 5 = 17280 images.

The problem in general is that segmentation of biomedical images is extremely challenging. Mostly, thresholding techniques are used because they are readily available in OpenCV, SimpleITK, and skimage. One of the best semi-neural-network cellular segmentation models to date is Cellpose. However, after getting a mask_image, one needs to label the objects within the image (the labelled mask is usually saved as a separate file at this stage). Then one needs to take measurements from the individual entities/cells within each image, per plate. It is not always feasible to combine the mask-generation and measurement steps and do them simultaneously.

For example, Cellpose (https://github.com/MouseLand/cellpose) does not provide batch processing of images the way the PyTorch API does. It uses PyTorch to generate masks, which can be parallelized, but for labelling those masks it uses other models that have not been parallelized over batches of images; it can process only one image at a time.

Therefore, after mask generation, we would like to take measurements from the individual objects within each image. A function like find_objects comes into play here: it lets one extract the objects/cells from each image, and we use the skimage.measure._regionprops.RegionProperties class (found at https://github.com/scikit-image/scikit-image/blob/v0.21.0/skimage/measure/_regionprops.py#L1046-L1329) to loop over those regions and extract those measurements. Also have a look at the implementation of the regionprops function within the same file to get an idea.
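
Concretely, the CPU workflow described above looks roughly like this, where mask_image is the labelled array and intensity_image the matching channel (both names assumed from the description):

```python
from skimage import measure

# One RegionProperties object per labelled cell.
regions = measure.regionprops(mask_image, intensity_image=intensity_image)
for r in regions:
    features = (r.label, r.area, r.mean_intensity, r.bbox)  # per-cell measurements
```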

This process runs only on the CPU and is extremely slow (it can take up to ~3 hours per plate, i.e. 17000 images). I have not been able to find any GPU implementation so far.


However, using the torch Dataset/DataLoader API with a custom collate function, one can:

  1. load multiple images, each with multiple channels, into memory;
  2. extract all the cells within a single image from all channels (assuming the label of each object is identical across channels, even though the regions might be bigger or smaller; for example, the cytoplasm of a cell contains its nucleus and nucleoli);
  3. combine them all into a single nested tensor using a custom collate function (a minimal sketch follows below);
  4. assuming all the functions used for extracting features are written in PyTorch (they are not at the moment, but most can easily be written in PyTorch from Python), extract all the necessary features from all objects/cell compartments from all images in a batch, batch-wise.

But the key to making all of this possible is a modified scipy.ndimage.find_objects, written in PyTorch, that can handle a batch of images with multiple channels.
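
A minimal sketch of step 3, assuming each dataset item is a list of variable-sized 2D cell crops (the dataset itself is hypothetical here):

```python
import torch

def collate_cells(batch):
    # batch: list of per-image lists of variable-sized 2D cell crops.
    # Flatten them into one nested tensor instead of zero-padding.
    crops = [crop for image_crops in batch for crop in image_crops]
    return torch.nested.nested_tensor(crops)

# loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collate_cells)
```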

Hope that helped.

@soulitzer
Contributor

Thanks for the context. If you do not need to backprop through find_objects with autograd, it seems the workaround is to just convert to NumPy and back.
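
In code, that round-trip looks like this (a toy mask stands in for a real one):

```python
import numpy as np
import torch
from scipy import ndimage

mask_np = np.zeros((64, 64), dtype=np.uint8)   # toy labelled mask
mask_np[5:15, 5:15] = 1
mask_np[30:40, 20:50] = 2

mask = torch.from_numpy(mask_np)               # shares memory with mask_np
slices = ndimage.find_objects(mask.numpy())    # run SciPy on the NumPy view
crops = [mask[s] for s in slices if s is not None]  # slice back into torch
```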

@kazemSafari
Author

kazemSafari commented Jun 5, 2023

That is correct.

But one will still have to do a triple for-loop in Python (over images, over objects/cells, over channels), which is very slow because of the GIL. I am using the multiprocessing module over the images, but it is still a slow CPU-bound implementation.

Also, the find_objects function only gives the coordinates of the objects; it does not extract them into a Python list or a C++ vector, so I can't readily convert them to a tensor.

@kazemSafari
Author

kazemSafari commented Jun 6, 2023

@soulitzer @albanD @jbschlosser
I just found out about torchvision.ops.masks_to_boxes (https://pytorch.org/vision/main/auto_examples/plot_repurposing_annotations.html#sphx-glr-download-auto-examples-plot-repurposing-annotations-py), which is exactly what I need and is equivalent to scipy.ndimage.find_objects.

The only small issue is that PyTorch and torchvision.io.read_image do not support the TIFF file format (np.uint16) yet.
The easy workaround is to read the TIFF file with tifffile or skimage.io's imread, then convert it to a tensor of an appropriate type, torch.int16 for example.
masks_to_boxes also supports tensors on the GPU.
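
A minimal version of that workaround (mask.tif is a hypothetical file; int32 avoids overflow for labels above 32767):

```python
import tifffile
import torch

img = tifffile.imread("mask.tif")             # numpy array, dtype uint16
mask = torch.from_numpy(img.astype("int32"))  # int32 is safe for any uint16 label
mask = mask.cuda()                            # optional: move to GPU
```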
For my mask image, which contains 418 different labelled objects, I got the following timings in seconds:

torch.Size([418, 4])
scipy.ndimage.find_objects cpu: 0.008491992950439453
torchvision.ops.masks_to_boxes cpu: 2.4579997062683105
transfer to gpu time: 0.3227546215057373
torchvision.ops.masks_to_boxes gpu first time: 1.5470712184906006
torchvision.ops.masks_to_boxes gpu second time: 0.1925187110900879
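
For reference, masks_to_boxes expects a stack of per-object boolean masks of shape (N, H, W), so the labelled mask first has to be expanded into one binary plane per label, as in the linked tutorial; that expansion likely accounts for part of the gap to SciPy's single pass:

```python
import torch
from torchvision.ops import masks_to_boxes

# mask: 2D labelled tensor (0 = background), as above.
labels = torch.unique(mask)
labels = labels[labels != 0]
binary = mask == labels[:, None, None]   # (N, H, W) boolean stack, one plane per object
boxes = masks_to_boxes(binary)           # (N, 4) boxes in (x1, y1, x2, y2) order
```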

I have just two more requests.

  1. Could you add another function to torchvision.ops that labels a boolean mask_image, the same as or similar to skimage.measure.label? Does one already exist in torch or torchvision?

  2. Given a mask_image, after extracting those bounding boxes, how can one extract the objects' intensities from the corresponding intensity image? Is there already functionality for that? (A rough CPU stopgap for both is sketched below.)
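
As a stopgap for both requests on the CPU side, a rough sketch (boolean_mask_np, boxes, and intensity_image are assumed from the context above):

```python
import torch
from skimage import measure

# Request 1 today: label a boolean mask with skimage (CPU only).
labelled = measure.label(boolean_mask_np)          # numpy int array of labels

# Request 2 today: crop intensities with the (x1, y1, x2, y2) boxes
# returned by masks_to_boxes; the max coordinates are inclusive.
for x1, y1, x2, y2 in boxes.to(torch.int64).tolist():
    crop = intensity_image[y1:y2 + 1, x1:x2 + 1]
```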

Thanks again for listening/reading.

@soulitzer
Contributor

Oh interesting, I saw that too, but they seemed different to me because masks_to_boxes just computes the min/max values of a single provided mask, whereas find_objects takes a single image with more than one mask and extracts the connected components.

Btw, for those requests you may want to post to https://github.com/pytorch/vision instead?

@kazemSafari
Author

Yeah sure. All three of them?!

It is slower in torch maybe due to the fact that PyTorch does not support uint16 natively.
Just out of curiosity, why is that? Is it because you guys do not like microscopes and medical imaging...
Just kidding!

@rgommers
Collaborator

It is slower in torch maybe due to the fact that PyTorch does not support uint16 natively.

This is the feature request, with a number of image processing use cases and a thumbs up about supporting it now that the PyTorch 2.0 infrastructure makes it easier to do so: gh-58734.

Btw, for those requests you may want to post to https://github.com/pytorch/vision instead?

Yeah sure. All three of them?!

I believe so, yes - all scipy.ndimage equivalents probably belong in torchvision.

Given the performance results you posted above, even the GPU version of torchvision.ops.masks_to_boxes is much slower than SciPy though, so there's probably not much point in using masks_to_boxes unless the performance can be improved.

@vadimkantorov
Contributor
