Skip to content

Commit 6385c20

Browse files
authored
Merge pull request #192 from fractal-analytics-platform/132_bounding_box
ref #132, add preliminary implementation of bounding box
2 parents 4eee8ce + 407500a commit 6385c20

6 files changed

+678
-395
lines changed

fractal_tasks_core/cellpose_segmentation.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,23 @@
2626
import anndata as ad
2727
import dask.array as da
2828
import numpy as np
29+
import pandas as pd
2930
import zarr
31+
from anndata.experimental import write_elem
3032
from cellpose import models
3133
from cellpose.core import use_gpu
3234

3335
import fractal_tasks_core
3436
from fractal_tasks_core.lib_pyramid_creation import build_pyramid
37+
from fractal_tasks_core.lib_regions_of_interest import (
38+
array_to_bounding_box_table,
39+
)
3540
from fractal_tasks_core.lib_regions_of_interest import (
3641
convert_ROI_table_to_indices,
3742
)
43+
from fractal_tasks_core.lib_remove_FOV_overlaps import (
44+
get_overlapping_pairs_3D,
45+
)
3846
from fractal_tasks_core.lib_zattrs_utils import extract_zyx_pixel_sizes
3947
from fractal_tasks_core.lib_zattrs_utils import rescale_datasets
4048

@@ -121,6 +129,7 @@ def cellpose_segmentation(
121129
flow_threshold: float = 0.4,
122130
model_type: str = "nuclei",
123131
ROI_table_name: str = "FOV_ROI_table",
132+
bounding_box_ROI_table_name: str = None,
124133
) -> Dict[str, Any]:
125134
"""
126135
Example inputs:
@@ -136,8 +145,8 @@ def cellpose_segmentation(
136145
# Set input path
137146
if len(input_paths) > 1:
138147
raise NotImplementedError
139-
in_path = input_paths[0]
140-
zarrurl = (in_path.parent.resolve() / component).as_posix() + "/"
148+
in_path = input_paths[0].parent
149+
zarrurl = (in_path.resolve() / component).as_posix() + "/"
141150
logger.info(zarrurl)
142151

143152
# Read useful parameters from metadata
@@ -166,6 +175,9 @@ def cellpose_segmentation(
166175
f"{zarrurl}.zattrs", level=0
167176
)
168177

178+
actual_res_pxl_sizes_zyx = extract_zyx_pixel_sizes(
179+
f"{zarrurl}.zattrs", level=labeling_level
180+
)
169181
# Create list of indices for 3D FOVs spanning the entire Z direction
170182
list_indices = convert_ROI_table_to_indices(
171183
ROI_table,
@@ -308,6 +320,10 @@ def cellpose_segmentation(
308320

309321
# Iterate over ROIs
310322
num_ROIs = len(list_indices)
323+
324+
if bounding_box_ROI_table_name:
325+
bbox_dataframe_list = []
326+
311327
logger.info(f"[{well_id}] Now starting loop over {num_ROIs} ROIs")
312328
for i_ROI, indices in enumerate(list_indices):
313329
# Define region
@@ -353,6 +369,20 @@ def cellpose_segmentation(
353369
f"but dtype={label_dtype}"
354370
)
355371

372+
if bounding_box_ROI_table_name:
373+
374+
bbox_df = array_to_bounding_box_table(
375+
fov_mask, actual_res_pxl_sizes_zyx
376+
)
377+
378+
bbox_dataframe_list.append(bbox_df)
379+
380+
overlap_list = []
381+
for df in bbox_dataframe_list:
382+
overlap_list.append(
383+
get_overlapping_pairs_3D(df, full_res_pxl_sizes_zyx)
384+
)
385+
356386
# Compute and store 0-th level to disk
357387
da.array(fov_mask).to_zarr(
358388
url=mask_zarr,
@@ -378,6 +408,23 @@ def cellpose_segmentation(
378408

379409
logger.info(f"[{well_id}] End building pyramids, exit")
380410

411+
if bounding_box_ROI_table_name:
412+
logger.info(f"[{well_id}] Writing bounding box table, exit")
413+
# Concatenate all FOV dataframes
414+
df_well = pd.concat(bbox_dataframe_list, axis=0, ignore_index=True)
415+
df_well.index = df_well.index.astype(str)
416+
# Convert all to float (warning: some would be int, in principle)
417+
bbox_dtype = np.float32
418+
df_well = df_well.astype(bbox_dtype)
419+
# Convert to anndata
420+
bbox_table = ad.AnnData(df_well, dtype=bbox_dtype)
421+
# Write to zarr group
422+
group_tables = zarr.group(f"{in_path}/{component}/tables/")
423+
write_elem(group_tables, bounding_box_ROI_table_name, bbox_table)
424+
logger.info(
425+
f"[{in_path}/{component}/tables/{bounding_box_ROI_table_name}"
426+
)
427+
381428
return {}
382429

383430

@@ -402,6 +449,7 @@ class TaskArguments(BaseModel):
402449
flow_threshold: float = 0.4
403450
model_type: str = "nuclei"
404451
ROI_table_name: str = "FOV_ROI_table"
452+
bounding_box_ROI_table_name: Optional[str] = None
405453

406454
run_fractal_task(
407455
task_function=cellpose_segmentation, TaskArgsModel=TaskArguments

fractal_tasks_core/lib_regions_of_interest.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,41 @@ def _inspect_ROI_table(
285285
print("Something went wrong in convert_ROI_table_to_indices\n", str(e))
286286

287287
return df
288+
289+
290+
def array_to_bounding_box_table(
291+
mask_array: np.ndarray, pxl_sizes_zyx: List[float]
292+
) -> pd.DataFrame:
293+
294+
"""
295+
Description
296+
297+
:param dummy: this is just a placeholder
298+
:type dummy: int
299+
"""
300+
301+
labels = np.unique(mask_array)
302+
labels = labels[labels > 0]
303+
elem_list = []
304+
for label in labels:
305+
label_match = np.where(mask_array == label)
306+
zmin, ymin, xmin = np.min(label_match, axis=1) * pxl_sizes_zyx
307+
zmax, ymax, xmax = (np.max(label_match, axis=1) + 1) * pxl_sizes_zyx
308+
309+
length_x = xmax - xmin
310+
length_y = ymax - ymin
311+
length_z = zmax - zmin
312+
elem_list.append((xmin, ymin, zmin, length_x, length_y, length_z))
313+
314+
df_columns = [
315+
"x_micrometer",
316+
"y_micrometer",
317+
"z_micrometer",
318+
"len_x_micrometer",
319+
"len_y_micrometer",
320+
"len_z_micrometer",
321+
]
322+
323+
ann_df = pd.DataFrame(np.array(elem_list), columns=df_columns)
324+
325+
return ann_df

fractal_tasks_core/lib_remove_FOV_overlaps.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@ def is_overlapping_2D(box1, box2, tol=0):
4848
return overlap_x and overlap_y
4949

5050

51+
def is_overlapping_3D(box1, box2, tol=0):
52+
"""
53+
Based on https://stackoverflow.com/a/70023212/19085332
54+
55+
box: (xmin, ymin, zmin, xmax, ymax, zmax)
56+
57+
:param dummy: this is just a placeholder
58+
:type dummy: int
59+
"""
60+
overlap_x = is_overlapping_1D(
61+
[box1[0], box1[3]], [box2[0], box2[3]], tol=tol
62+
)
63+
overlap_y = is_overlapping_1D(
64+
[box1[1], box1[4]], [box2[1], box2[4]], tol=tol
65+
)
66+
overlap_z = is_overlapping_1D(
67+
[box1[2], box1[5]], [box2[2], box2[5]], tol=tol
68+
)
69+
return overlap_x and overlap_y and overlap_z
70+
71+
5172
def get_overlapping_pair(tmp_df, tol=0):
5273
"""
5374
Description
@@ -66,6 +87,49 @@ def get_overlapping_pair(tmp_df, tol=0):
6687
return False
6788

6889

90+
def get_overlapping_pairs_3D(tmp_df, pixel_sizes):
91+
"""
92+
Description
93+
94+
:param dummy: this is just a placeholder
95+
:type dummy: int
96+
"""
97+
# NOTE: here we use positional indices (i.e. starting from 0)
98+
tol = 10e-10
99+
if tol > min(pixel_sizes) / 1e3:
100+
raise Exception(f"{tol=} but {pixel_sizes=}")
101+
new_tmp_df = tmp_df.copy()
102+
103+
new_tmp_df["x_micrometer_max"] = (
104+
new_tmp_df["x_micrometer"] + new_tmp_df["len_x_micrometer"]
105+
)
106+
new_tmp_df["y_micrometer_max"] = (
107+
new_tmp_df["y_micrometer"] + new_tmp_df["len_y_micrometer"]
108+
)
109+
new_tmp_df["z_micrometer_max"] = (
110+
new_tmp_df["z_micrometer"] + new_tmp_df["len_z_micrometer"]
111+
)
112+
113+
new_tmp_df.drop(labels=["len_x_micrometer"], axis=1, inplace=True)
114+
new_tmp_df.drop(labels=["len_y_micrometer"], axis=1, inplace=True)
115+
new_tmp_df.drop(labels=["len_z_micrometer"], axis=1, inplace=True)
116+
num_lines = len(new_tmp_df.index)
117+
overlapping_list = []
118+
# pos_ind_1 and pos_ind_2 are labels value
119+
for pos_ind_1 in range(num_lines):
120+
for pos_ind_2 in range(pos_ind_1):
121+
if is_overlapping_3D(
122+
new_tmp_df.iloc[pos_ind_1], new_tmp_df.iloc[pos_ind_2], tol=tol
123+
):
124+
# we accumulate tuples of overlapping labels
125+
overlapping_list.append((pos_ind_1, pos_ind_2))
126+
if len(overlapping_list) > 0:
127+
raise ValueError(
128+
f"{overlapping_list} " f"List of pair of bounding box overlaps"
129+
)
130+
return overlapping_list
131+
132+
69133
def remove_FOV_overlaps(df):
70134
"""
71135
Description

0 commit comments

Comments
 (0)