Skip to content

Commit 585d050

Browse files
committed
Move relabeling into segment_ROI function to fix #785
1 parent 615e6c7 commit 585d050

File tree

2 files changed

+37
-23
lines changed

2 files changed

+37
-23
lines changed

fractal_tasks_core/tasks/cellpose_segmentation.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -73,20 +73,26 @@
7373

7474
def segment_ROI(
7575
x: np.ndarray,
76+
num_labels_tot: dict[str, int],
7677
model: models.CellposeModel = None,
7778
do_3D: bool = True,
7879
channels: list[int] = [0, 0],
7980
diameter: float = 30.0,
8081
normalize: CellposeCustomNormalizer = CellposeCustomNormalizer(),
8182
normalize2: Optional[CellposeCustomNormalizer] = None,
8283
label_dtype: Optional[np.dtype] = None,
84+
relabeling: bool = True,
8385
advanced_cellpose_model_params: CellposeModelParams = CellposeModelParams(), # noqa: E501
8486
) -> np.ndarray:
8587
"""
8688
Internal function that runs Cellpose segmentation for a single ROI.
8789
8890
Args:
8991
x: 4D numpy array.
92+
num_labels_tot: Number of labels already in total image. Used for
93+
relabeling purposes. Using a dict to have a mutable object that
94+
can be edited from within the function without having to be passed
95+
back through the masked_loading_wrapper.
9096
model: An instance of `models.CellposeModel`.
9197
do_3D: If `True`, cellpose runs in 3D mode: runs on xy, xz & yz planes,
9298
then averages the flows.
@@ -107,6 +113,7 @@ def segment_ROI(
107113
normalized with default settings, both channels need to be
108114
normalized with default settings.
109115
label_dtype: Label images are cast into this `np.dtype`.
116+
relabeling: Whether relabeling based on num_labels_tot is performed.
110117
advanced_cellpose_model_params: Advanced Cellpose model parameters
111118
that are passed to the Cellpose `model.eval` method.
112119
"""
@@ -163,6 +170,23 @@ def segment_ROI(
163170
f" {advanced_cellpose_model_params.flow_threshold=}"
164171
)
165172

173+
# Shift labels and update relabeling counters
174+
if relabeling:
175+
num_labels_roi = np.max(mask)
176+
mask[mask > 0] += num_labels_tot["num_labels_tot"]
177+
num_labels_tot["num_labels_tot"] += num_labels_roi
178+
179+
# Write some logs
180+
logger.info(f"ROI had {num_labels_roi=}, {num_labels_tot=}")
181+
182+
# Check that total number of labels is under control
183+
if num_labels_tot["num_labels_tot"] > np.iinfo(label_dtype).max:
184+
raise ValueError(
185+
"ERROR in re-labeling:"
186+
f"Reached {num_labels_tot} labels, "
187+
f"but dtype={label_dtype}"
188+
)
189+
166190
return mask.astype(label_dtype)
167191

168192

@@ -438,8 +462,7 @@ def cellpose_segmentation(
438462
logger.info(f"{data_zyx_c2.chunks}")
439463

440464
# Counters for relabeling
441-
if relabeling:
442-
num_labels_tot = 0
465+
num_labels_tot = {"num_labels_tot": 0}
443466

444467
# Iterate over ROIs
445468
num_ROIs = len(list_indices)
@@ -485,13 +508,15 @@ def cellpose_segmentation(
485508

486509
# Prepare keyword arguments for segment_ROI function
487510
kwargs_segment_ROI = dict(
511+
num_labels_tot=num_labels_tot,
488512
model=model,
489513
channels=channels,
490514
do_3D=do_3D,
491515
label_dtype=label_dtype,
492516
diameter=diameter_level0 / coarsening_xy**level,
493517
normalize=channel.normalize,
494518
normalize2=channel2.normalize,
519+
relabeling=relabeling,
495520
advanced_cellpose_model_params=advanced_cellpose_model_params,
496521
)
497522

@@ -515,23 +540,6 @@ def cellpose_segmentation(
515540
preprocessing_kwargs=preprocessing_kwargs,
516541
)
517542

518-
# Shift labels and update relabeling counters
519-
if relabeling:
520-
num_labels_roi = np.max(new_label_img)
521-
new_label_img[new_label_img > 0] += num_labels_tot
522-
num_labels_tot += num_labels_roi
523-
524-
# Write some logs
525-
logger.info(f"ROI {indices}, {num_labels_roi=}, {num_labels_tot=}")
526-
527-
# Check that total number of labels is under control
528-
if num_labels_tot > np.iinfo(label_dtype).max:
529-
raise ValueError(
530-
"ERROR in re-labeling:"
531-
f"Reached {num_labels_tot} labels, "
532-
f"but dtype={label_dtype}"
533-
)
534-
535543
if output_ROI_table:
536544
bbox_df = array_to_bounding_box_table(
537545
new_label_img,

tests/tasks/test_workflows_cellpose_segmentation.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -653,10 +653,6 @@ def test_cellpose_within_masked_bb_with_overlap(
653653
patched_cellpose_eval,
654654
)
655655

656-
# Setup caplog fixture, see
657-
# https://docs.pytest.org/en/stable/how-to/logging.html#caplog-fixture
658-
caplog.set_level(logging.WARNING)
659-
660656
# Use pre-made 3D zarr
661657
zarr_dir = tmp_path / "tmp_out/"
662658
zarr_urls = prepare_3D_zarr(str(zarr_dir), zenodo_zarr)
@@ -717,6 +713,16 @@ def test_cellpose_within_masked_bb_with_overlap(
717713
# mock testing ignores the input
718714
# assert np.max(secondary_segmentation) == 4
719715

716+
# Ensure that labels stay in correct proportions => relabeling doesn't do
717+
# reassignment
718+
label1 = np.sum(secondary_segmentation == 1)
719+
label3 = np.sum(secondary_segmentation == 3)
720+
label5 = np.sum(secondary_segmentation == 5)
721+
label7 = np.sum(secondary_segmentation == 7)
722+
assert label1 == label3
723+
assert label1 == label5
724+
assert label1 == label7
725+
720726

721727
def test_workflow_with_per_FOV_labeling_via_script(
722728
tmp_path: Path,

0 commit comments

Comments
 (0)