Skip to content

Commit c445edd

Browse files
authored
Merge pull request #1291 from skoudoro/update-data-per-streamline
Update data_per_streamline capabilities
2 parents 714d757 + c907b45 commit c445edd

File tree

2 files changed

+48
-16
lines changed

2 files changed

+48
-16
lines changed

nibabel/streamlines/tests/test_tractogram.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def make_dummy_streamline(nb_points):
8080
'mean_curvature': np.array([1.11], dtype='f4'),
8181
'mean_torsion': np.array([1.22], dtype='f4'),
8282
'mean_colors': np.array([1, 0, 0], dtype='f4'),
83+
'clusters_labels': np.array([0, 1], dtype='i4'),
8384
}
8485

8586
elif nb_points == 2:
@@ -92,6 +93,7 @@ def make_dummy_streamline(nb_points):
9293
'mean_curvature': np.array([2.11], dtype='f4'),
9394
'mean_torsion': np.array([2.22], dtype='f4'),
9495
'mean_colors': np.array([0, 1, 0], dtype='f4'),
96+
'clusters_labels': np.array([2, 3, 4], dtype='i4'),
9597
}
9698

9799
elif nb_points == 5:
@@ -104,6 +106,7 @@ def make_dummy_streamline(nb_points):
104106
'mean_curvature': np.array([3.11], dtype='f4'),
105107
'mean_torsion': np.array([3.22], dtype='f4'),
106108
'mean_colors': np.array([0, 0, 1], dtype='f4'),
109+
'clusters_labels': np.array([5, 6, 7, 8], dtype='i4'),
107110
}
108111

109112
return streamline, data_per_point, data_for_streamline
@@ -119,6 +122,7 @@ def setup_module():
119122
DATA['mean_curvature'] = []
120123
DATA['mean_torsion'] = []
121124
DATA['mean_colors'] = []
125+
DATA['clusters_labels'] = []
122126
for nb_points in [1, 2, 5]:
123127
data = make_dummy_streamline(nb_points)
124128
streamline, data_per_point, data_for_streamline = data
@@ -128,12 +132,14 @@ def setup_module():
128132
DATA['mean_curvature'].append(data_for_streamline['mean_curvature'])
129133
DATA['mean_torsion'].append(data_for_streamline['mean_torsion'])
130134
DATA['mean_colors'].append(data_for_streamline['mean_colors'])
135+
DATA['clusters_labels'].append(data_for_streamline['clusters_labels'])
131136

132137
DATA['data_per_point'] = {'colors': DATA['colors'], 'fa': DATA['fa']}
133138
DATA['data_per_streamline'] = {
134139
'mean_curvature': DATA['mean_curvature'],
135140
'mean_torsion': DATA['mean_torsion'],
136141
'mean_colors': DATA['mean_colors'],
142+
'clusters_labels': DATA['clusters_labels'],
137143
}
138144

139145
DATA['empty_tractogram'] = Tractogram(affine_to_rasmm=np.eye(4))
@@ -154,6 +160,7 @@ def setup_module():
154160
'mean_curvature': lambda: (e for e in DATA['mean_curvature']),
155161
'mean_torsion': lambda: (e for e in DATA['mean_torsion']),
156162
'mean_colors': lambda: (e for e in DATA['mean_colors']),
163+
'clusters_labels': lambda: (e for e in DATA['clusters_labels']),
157164
}
158165

159166
DATA['lazy_tractogram'] = LazyTractogram(
@@ -214,7 +221,10 @@ def test_per_array_dict_creation(self):
214221
data_dict = PerArrayDict(nb_streamlines, data_per_streamline)
215222
assert data_dict.keys() == data_per_streamline.keys()
216223
for k in data_dict.keys():
217-
assert_array_equal(data_dict[k], data_per_streamline[k])
224+
if isinstance(data_dict[k], np.ndarray) and np.all(
225+
data_dict[k].shape[0] == data_dict[k].shape
226+
):
227+
assert_array_equal(data_dict[k], data_per_streamline[k])
218228

219229
del data_dict['mean_curvature']
220230
assert len(data_dict) == len(data_per_streamline) - 1
@@ -224,7 +234,10 @@ def test_per_array_dict_creation(self):
224234
data_dict = PerArrayDict(nb_streamlines, data_per_streamline)
225235
assert data_dict.keys() == data_per_streamline.keys()
226236
for k in data_dict.keys():
227-
assert_array_equal(data_dict[k], data_per_streamline[k])
237+
if isinstance(data_dict[k], np.ndarray) and np.all(
238+
data_dict[k].shape[0] == data_dict[k].shape
239+
):
240+
assert_array_equal(data_dict[k], data_per_streamline[k])
228241

229242
del data_dict['mean_curvature']
230243
assert len(data_dict) == len(data_per_streamline) - 1
@@ -234,7 +247,10 @@ def test_per_array_dict_creation(self):
234247
data_dict = PerArrayDict(nb_streamlines, **data_per_streamline)
235248
assert data_dict.keys() == data_per_streamline.keys()
236249
for k in data_dict.keys():
237-
assert_array_equal(data_dict[k], data_per_streamline[k])
250+
if isinstance(data_dict[k], np.ndarray) and np.all(
251+
data_dict[k].shape[0] == data_dict[k].shape
252+
):
253+
assert_array_equal(data_dict[k], data_per_streamline[k])
238254

239255
del data_dict['mean_curvature']
240256
assert len(data_dict) == len(data_per_streamline) - 1
@@ -261,6 +277,7 @@ def test_extend(self):
261277
'mean_curvature': 2 * np.array(DATA['mean_curvature']),
262278
'mean_torsion': 3 * np.array(DATA['mean_torsion']),
263279
'mean_colors': 4 * np.array(DATA['mean_colors']),
280+
'clusters_labels': 5 * np.array(DATA['clusters_labels'], dtype=object),
264281
}
265282
sdict2 = PerArrayDict(len(DATA['tractogram']), new_data)
266283

@@ -284,7 +301,8 @@ def test_extend(self):
284301
'mean_curvature': 2 * np.array(DATA['mean_curvature']),
285302
'mean_torsion': 3 * np.array(DATA['mean_torsion']),
286303
'mean_colors': 4 * np.array(DATA['mean_colors']),
287-
'other': 5 * np.array(DATA['mean_colors']),
304+
'clusters_labels': 5 * np.array(DATA['clusters_labels'], dtype=object),
305+
'other': 6 * np.array(DATA['mean_colors']),
288306
}
289307
sdict2 = PerArrayDict(len(DATA['tractogram']), new_data)
290308

@@ -305,6 +323,7 @@ def test_extend(self):
305323
'mean_curvature': 2 * np.array(DATA['mean_curvature']),
306324
'mean_torsion': 3 * np.array(DATA['mean_torsion']),
307325
'mean_colors': 4 * np.array(DATA['mean_torsion']),
326+
'clusters_labels': 5 * np.array(DATA['clusters_labels'], dtype=object),
308327
}
309328
sdict2 = PerArrayDict(len(DATA['tractogram']), new_data)
310329
with pytest.raises(ValueError):
@@ -441,7 +460,10 @@ def test_lazydict_creation(self):
441460
assert is_lazy_dict(data_dict)
442461
assert data_dict.keys() == expected_keys
443462
for k in data_dict.keys():
444-
assert_array_equal(list(data_dict[k]), list(DATA['data_per_streamline'][k]))
463+
if isinstance(data_dict[k], np.ndarray) and np.all(
464+
data_dict[k].shape[0] == data_dict[k].shape
465+
):
466+
assert_array_equal(list(data_dict[k]), list(DATA['data_per_streamline'][k]))
445467

446468
assert len(data_dict) == len(DATA['data_per_streamline_func'])
447469

@@ -578,6 +600,7 @@ def test_tractogram_add_new_data(self):
578600
t.data_per_streamline['mean_curvature'] = DATA['mean_curvature']
579601
t.data_per_streamline['mean_torsion'] = DATA['mean_torsion']
580602
t.data_per_streamline['mean_colors'] = DATA['mean_colors']
603+
t.data_per_streamline['clusters_labels'] = DATA['clusters_labels']
581604
assert_tractogram_equal(t, DATA['tractogram'])
582605

583606
# Retrieve tractogram by their index.
@@ -598,6 +621,7 @@ def test_tractogram_add_new_data(self):
598621
t.data_per_streamline['mean_curvature'] = DATA['mean_curvature']
599622
t.data_per_streamline['mean_torsion'] = DATA['mean_torsion']
600623
t.data_per_streamline['mean_colors'] = DATA['mean_colors']
624+
t.data_per_streamline['clusters_labels'] = DATA['clusters_labels']
601625
assert_tractogram_equal(t, DATA['tractogram'])
602626

603627
def test_tractogram_copy(self):
@@ -647,14 +671,6 @@ def test_creating_invalid_tractogram(self):
647671
with pytest.raises(ValueError):
648672
Tractogram(streamlines=DATA['streamlines'], data_per_point={'scalars': scalars})
649673

650-
# Inconsistent dimension for a data_per_streamline.
651-
properties = [[1.11, 1.22], [2.11], [3.11, 3.22]]
652-
653-
with pytest.raises(ValueError):
654-
Tractogram(
655-
streamlines=DATA['streamlines'], data_per_streamline={'properties': properties}
656-
)
657-
658674
# Too many dimension for a data_per_streamline.
659675
properties = [
660676
np.array([[1.11], [1.22]], dtype='f4'),
@@ -870,6 +886,7 @@ def test_lazy_tractogram_from_data_func(self):
870886
DATA['mean_curvature'],
871887
DATA['mean_torsion'],
872888
DATA['mean_colors'],
889+
DATA['clusters_labels'],
873890
]
874891

875892
def _data_gen():
@@ -879,6 +896,7 @@ def _data_gen():
879896
'mean_curvature': d[3],
880897
'mean_torsion': d[4],
881898
'mean_colors': d[5],
899+
'clusters_labels': d[6],
882900
}
883901
yield TractogramItem(d[0], data_for_streamline, data_for_points)
884902

nibabel/streamlines/tractogram.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
import numbers
3-
from collections.abc import MutableMapping
3+
import types
4+
from collections.abc import Iterable, MutableMapping
45
from warnings import warn
56

67
import numpy as np
@@ -101,15 +102,28 @@ def __init__(self, n_rows=0, *args, **kwargs):
101102
super().__init__(*args, **kwargs)
102103

103104
def __setitem__(self, key, value):
104-
value = np.asarray(list(value))
105+
dtype = np.float64
106+
107+
if isinstance(value, types.GeneratorType):
108+
value = list(value)
109+
110+
if isinstance(value, np.ndarray):
111+
dtype = value.dtype
112+
elif not all(len(v) == len(value[0]) for v in value[1:]):
113+
dtype = object
114+
115+
value = np.asarray(value, dtype=dtype)
105116

106117
if value.ndim == 1 and value.dtype != object:
107118
# Reshape without copy
108119
value.shape = (len(value), 1)
109120

110-
if value.ndim != 2:
121+
if value.ndim != 2 and value.dtype != object:
111122
raise ValueError('data_per_streamline must be a 2D array.')
112123

124+
if value.dtype == object and not all(isinstance(v, Iterable) for v in value):
125+
raise ValueError('data_per_streamline must be a 2D array')
126+
113127
# We make sure there is the right amount of values
114128
if 0 < self.n_rows != len(value):
115129
msg = f'The number of values ({len(value)}) should match n_elements ({self.n_rows}).'

0 commit comments

Comments
 (0)