Skip to content

Commit ed2205c

Browse files
committed
BROKEN stash initial work on new Channel model (ref #386)
1 parent 9cc9078 commit ed2205c

File tree

3 files changed

+67
-48
lines changed

3 files changed

+67
-48
lines changed

fractal_tasks_core/lib_channels.py

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,30 @@
1515
"""
1616
import logging
1717
from typing import Any
18-
from typing import Dict
1918
from typing import List
20-
from typing import Sequence
19+
from typing import Optional
2120

2221
import zarr
22+
from pydantic import BaseModel
23+
24+
25+
class ChannelWindow(BaseModel):
26+
min: str
27+
max: str
28+
start: Optional[str]
29+
end: Optional[str]
30+
31+
32+
class Channel(BaseModel):
33+
wavelength_id: str
34+
label: Optional[str]
35+
index: Optional[int]
36+
active: bool = True
37+
coefficient: int = 1
38+
colormap: Optional[str]
39+
family: str = "linear"
40+
inverted: bool = False
41+
window: Optional[ChannelWindow]
2342

2443

2544
class ChannelNotFoundError(ValueError):
@@ -31,19 +50,11 @@ class ChannelNotFoundError(ValueError):
3150
pass
3251

3352

34-
def validate_allowed_channel_input(allowed_channels: Sequence[Dict[str, Any]]):
53+
def validate_allowed_channel_input(allowed_channels: List[Channel]):
3554
"""
36-
Check that (1) each channel has a wavelength_id key, and (2) the
37-
wavelength_id values are unique.
55+
Check that the `wavelength_id` values are unique across channels
3856
"""
39-
try:
40-
wavelength_ids = [c["wavelength_id"] for c in allowed_channels]
41-
except KeyError as e:
42-
raise KeyError(
43-
"Missing wavelength_id key in some channel.\n"
44-
f"{allowed_channels=}\n"
45-
f"Original error: {str(e)}"
46-
)
57+
wavelength_ids = [c.wavelength_id for c in allowed_channels]
4758
if len(set(wavelength_ids)) < len(wavelength_ids):
4859
raise ValueError(
4960
f"Non-unique wavelength_id's in {wavelength_ids}\n"
@@ -73,16 +84,16 @@ def check_well_channel_labels(*, well_zarr_path: str) -> None:
7384

7485
# For each pair of channel-labels lists, verify they do not overlap
7586
for ind_1, channels_1 in enumerate(list_of_channel_lists):
76-
labels_1 = set([c["label"] for c in channels_1])
87+
labels_1 = set([c.label for c in channels_1])
7788
for ind_2 in range(ind_1):
7889
channels_2 = list_of_channel_lists[ind_2]
79-
labels_2 = set([c["label"] for c in channels_2])
90+
labels_2 = set([c.label for c in channels_2])
8091
intersection = labels_1 & labels_2
8192
if intersection:
8293
hint = (
83-
"Are you parsing fields of view into separate OME-Zarr"
84-
" images? This could lead to non-unique channel labels"
85-
", and then could be the reason of the error"
94+
"Are you parsing fields of view into separate OME-Zarr "
95+
"images? This could lead to non-unique channel labels, "
96+
"and then could be the reason of the error"
8697
)
8798
raise ValueError(
8899
"Non-unique channel labels\n"
@@ -92,7 +103,7 @@ def check_well_channel_labels(*, well_zarr_path: str) -> None:
92103

93104
def get_channel_from_image_zarr(
94105
*, image_zarr_path: str, label: str = None, wavelength_id: str = None
95-
) -> Dict[str, Any]:
106+
) -> Channel:
96107
"""
97108
Extract a channel from OME-NGFF zarr attributes
98109
@@ -112,20 +123,23 @@ def get_channel_from_image_zarr(
112123
return channel
113124

114125

115-
def get_omero_channel_list(*, image_zarr_path: str) -> List[Dict[str, Any]]:
126+
def get_omero_channel_list(*, image_zarr_path: str) -> List[Channel]:
116127
"""
117128
Extract the list of channels from OME-NGFF zarr attributes
118129
119130
:param image_zarr_path: Path to an OME-NGFF image zarr group
120131
:returns: A list of channel dictionaries
121132
"""
122133
group = zarr.open_group(image_zarr_path, mode="r+")
123-
return group.attrs["omero"]["channels"]
134+
channels_dicts = group.attrs["omero"]["channels"]
135+
# FIXME what is the type of channels_dicts??
136+
channels = [Channel(**c) for c in channels_dicts]
137+
return channels
124138

125139

126140
def get_channel_from_list(
127-
*, channels: Sequence[Dict], label: str = None, wavelength_id: str = None
128-
) -> Dict[str, Any]:
141+
*, channels: List[Channel], label: str = None, wavelength_id: str = None
142+
) -> Channel:
129143
"""
130144
Find matching channel in a list
131145
@@ -147,16 +161,14 @@ def get_channel_from_list(
147161
matching_channels = [
148162
c
149163
for c in channels
150-
if (
151-
c["label"] == label and c["wavelength_id"] == wavelength_id
152-
)
164+
if (c.label == label and c.wavelength_id == wavelength_id)
153165
]
154166
else:
155-
matching_channels = [c for c in channels if c["label"] == label]
167+
matching_channels = [c for c in channels if c.label == label]
156168
else:
157169
if wavelength_id:
158170
matching_channels = [
159-
c for c in channels if c["wavelength_id"] == wavelength_id
171+
c for c in channels if c.wavelength_id == wavelength_id
160172
]
161173
else:
162174
raise ValueError(
@@ -178,16 +190,16 @@ def get_channel_from_list(
178190
raise ValueError(f"Inconsistent set of channels: {channels}")
179191

180192
channel = matching_channels[0]
181-
channel["index"] = channels.index(channel)
193+
channel.index = channels.index(channel)
182194
return channel
183195

184196

185197
def define_omero_channels(
186198
*,
187-
channels: Sequence[Dict[str, Any]],
199+
channels: List[Channel],
188200
bit_depth: int,
189201
label_prefix: str = None,
190-
) -> List[Dict[str, Any]]:
202+
) -> List[dict[str, Any]]:
191203
"""
192204
Update a channel list to use it in the OMERO/channels metadata
193205
@@ -211,11 +223,11 @@ def define_omero_channels(
211223
default_colormaps = ["00FFFF", "FF00FF", "FFFF00"]
212224

213225
for channel in channels:
214-
wavelength_id = channel["wavelength_id"]
226+
wavelength_id = channel.wavelength_id
215227

216228
# Always set a label
217229
try:
218-
label = channel["label"]
230+
label = channel.label
219231
except KeyError:
220232
default_label = wavelength_id
221233
if label_prefix:
@@ -227,7 +239,7 @@ def define_omero_channels(
227239

228240
# Set colormap attribute. If not specificed, use the default ones (for
229241
# the first three channels) or gray
230-
colormap = channel.get("colormap", None)
242+
colormap = channel.colormap
231243
if colormap is None:
232244
try:
233245
colormap = default_colormaps.pop()
@@ -239,7 +251,7 @@ def define_omero_channels(
239251
"min": 0,
240252
"max": 2**bit_depth - 1,
241253
}
242-
if "start" in channel.keys() and "end" in channel.keys():
254+
if "start" in channel.dict().keys() and "end" in channel.dict().keys():
243255
window["start"] = channel["start"]
244256
window["end"] = channel["end"]
245257

fractal_tasks_core/tasks/create_ome_zarr.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
import pandas as pd
2525
import zarr
2626
from anndata.experimental import write_elem
27+
from devtools import debug
2728

2829
import fractal_tasks_core
30+
from fractal_tasks_core.lib_channels import Channel
2931
from fractal_tasks_core.lib_channels import check_well_channel_labels
3032
from fractal_tasks_core.lib_channels import define_omero_channels
3133
from fractal_tasks_core.lib_channels import validate_allowed_channel_input
@@ -51,7 +53,7 @@ def create_ome_zarr(
5153
metadata: Dict[str, Any],
5254
image_extension: str = "tif",
5355
image_glob_patterns: Optional[list[str]] = None,
54-
allowed_channels: Sequence[Dict[str, Any]],
56+
allowed_channels: List[Channel],
5557
num_levels: int = 2,
5658
coarsening_xy: int = 2,
5759
metadata_table: str = "mrf_mlf",
@@ -108,6 +110,9 @@ def create_ome_zarr(
108110
dict_plate_prefixes: Dict[str, Any] = {}
109111

110112
# Preliminary checks on allowed_channels argument
113+
allowed_channels_raw = allowed_channels.copy()
114+
allowed_channels = [Channel(**c) for c in allowed_channels_raw]
115+
debug(allowed_channels)
111116
validate_allowed_channel_input(allowed_channels)
112117

113118
for in_path_str in input_paths:
@@ -188,7 +193,7 @@ def create_ome_zarr(
188193

189194
# Check that all channels are in the allowed_channels
190195
allowed_wavelength_ids = [
191-
channel["wavelength_id"] for channel in allowed_channels
196+
channel.wavelength_id for channel in allowed_channels
192197
]
193198
if not set(actual_wavelength_ids).issubset(set(allowed_wavelength_ids)):
194199
msg = "ERROR in create_ome_zarr\n"
@@ -201,7 +206,7 @@ def create_ome_zarr(
201206
actual_channels = [
202207
channel
203208
for channel in allowed_channels
204-
if channel["wavelength_id"] in actual_wavelength_ids
209+
if channel.wavelength_id in actual_wavelength_ids
205210
]
206211

207212
zarrurls: Dict[str, List[str]] = {"plate": [], "well": [], "image": []}

tests/test_unit_channels_addressing.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,36 @@
33

44
from devtools import debug
55

6+
from fractal_tasks_core.lib_channels import Channel
67
from fractal_tasks_core.lib_channels import get_channel_from_list
78

89

910
def test_get_channel(testdata_path: Path):
1011
with (testdata_path / "omero/channels_list.json").open("r") as f:
11-
omero_channels = json.load(f)
12+
omero_channels_dict = json.load(f)
13+
omero_channels = [Channel(**c) for c in omero_channels_dict]
1214
debug(omero_channels)
1315

1416
channel = get_channel_from_list(channels=omero_channels, label="label_1")
1517
debug(channel)
16-
assert channel["label"] == "label_1"
17-
assert channel["wavelength_id"] == "wavelength_id_1"
18-
assert channel["index"] == 0
18+
assert channel.label == "label_1"
19+
assert channel.wavelength_id == "wavelength_id_1"
20+
assert channel.index == 0
1921

2022
channel = get_channel_from_list(
2123
channels=omero_channels, wavelength_id="wavelength_id_2"
2224
)
2325
debug(channel)
24-
assert channel["label"] == "label_2"
25-
assert channel["wavelength_id"] == "wavelength_id_2"
26-
assert channel["index"] == 1
26+
assert channel.label == "label_2"
27+
assert channel.wavelength_id == "wavelength_id_2"
28+
assert channel.index == 1
2729

2830
channel = get_channel_from_list(
2931
channels=omero_channels,
3032
label="label_2",
3133
wavelength_id="wavelength_id_2",
3234
)
3335
debug(channel)
34-
assert channel["label"] == "label_2"
35-
assert channel["wavelength_id"] == "wavelength_id_2"
36-
assert channel["index"] == 1
36+
assert channel.label == "label_2"
37+
assert channel.wavelength_id == "wavelength_id_2"
38+
assert channel.index == 1

0 commit comments

Comments
 (0)