@@ -80,6 +80,7 @@ def make_dummy_streamline(nb_points):
80
80
'mean_curvature' : np .array ([1.11 ], dtype = 'f4' ),
81
81
'mean_torsion' : np .array ([1.22 ], dtype = 'f4' ),
82
82
'mean_colors' : np .array ([1 , 0 , 0 ], dtype = 'f4' ),
83
+ 'clusters_labels' : np .array ([0 , 1 ], dtype = 'i4' ),
83
84
}
84
85
85
86
elif nb_points == 2 :
@@ -92,6 +93,7 @@ def make_dummy_streamline(nb_points):
92
93
'mean_curvature' : np .array ([2.11 ], dtype = 'f4' ),
93
94
'mean_torsion' : np .array ([2.22 ], dtype = 'f4' ),
94
95
'mean_colors' : np .array ([0 , 1 , 0 ], dtype = 'f4' ),
96
+ 'clusters_labels' : np .array ([2 , 3 , 4 ], dtype = 'i4' ),
95
97
}
96
98
97
99
elif nb_points == 5 :
@@ -104,6 +106,7 @@ def make_dummy_streamline(nb_points):
104
106
'mean_curvature' : np .array ([3.11 ], dtype = 'f4' ),
105
107
'mean_torsion' : np .array ([3.22 ], dtype = 'f4' ),
106
108
'mean_colors' : np .array ([0 , 0 , 1 ], dtype = 'f4' ),
109
+ 'clusters_labels' : np .array ([5 , 6 , 7 , 8 ], dtype = 'i4' ),
107
110
}
108
111
109
112
return streamline , data_per_point , data_for_streamline
@@ -119,6 +122,7 @@ def setup_module():
119
122
DATA ['mean_curvature' ] = []
120
123
DATA ['mean_torsion' ] = []
121
124
DATA ['mean_colors' ] = []
125
+ DATA ['clusters_labels' ] = []
122
126
for nb_points in [1 , 2 , 5 ]:
123
127
data = make_dummy_streamline (nb_points )
124
128
streamline , data_per_point , data_for_streamline = data
@@ -128,12 +132,14 @@ def setup_module():
128
132
DATA ['mean_curvature' ].append (data_for_streamline ['mean_curvature' ])
129
133
DATA ['mean_torsion' ].append (data_for_streamline ['mean_torsion' ])
130
134
DATA ['mean_colors' ].append (data_for_streamline ['mean_colors' ])
135
+ DATA ['clusters_labels' ].append (data_for_streamline ['clusters_labels' ])
131
136
132
137
DATA ['data_per_point' ] = {'colors' : DATA ['colors' ], 'fa' : DATA ['fa' ]}
133
138
DATA ['data_per_streamline' ] = {
134
139
'mean_curvature' : DATA ['mean_curvature' ],
135
140
'mean_torsion' : DATA ['mean_torsion' ],
136
141
'mean_colors' : DATA ['mean_colors' ],
142
+ 'clusters_labels' : DATA ['clusters_labels' ],
137
143
}
138
144
139
145
DATA ['empty_tractogram' ] = Tractogram (affine_to_rasmm = np .eye (4 ))
@@ -154,6 +160,7 @@ def setup_module():
154
160
'mean_curvature' : lambda : (e for e in DATA ['mean_curvature' ]),
155
161
'mean_torsion' : lambda : (e for e in DATA ['mean_torsion' ]),
156
162
'mean_colors' : lambda : (e for e in DATA ['mean_colors' ]),
163
+ 'clusters_labels' : lambda : (e for e in DATA ['clusters_labels' ]),
157
164
}
158
165
159
166
DATA ['lazy_tractogram' ] = LazyTractogram (
@@ -214,7 +221,10 @@ def test_per_array_dict_creation(self):
214
221
data_dict = PerArrayDict (nb_streamlines , data_per_streamline )
215
222
assert data_dict .keys () == data_per_streamline .keys ()
216
223
for k in data_dict .keys ():
217
- assert_array_equal (data_dict [k ], data_per_streamline [k ])
224
+ if isinstance (data_dict [k ], np .ndarray ) and np .all (
225
+ data_dict [k ].shape [0 ] == data_dict [k ].shape
226
+ ):
227
+ assert_array_equal (data_dict [k ], data_per_streamline [k ])
218
228
219
229
del data_dict ['mean_curvature' ]
220
230
assert len (data_dict ) == len (data_per_streamline ) - 1
@@ -224,7 +234,10 @@ def test_per_array_dict_creation(self):
224
234
data_dict = PerArrayDict (nb_streamlines , data_per_streamline )
225
235
assert data_dict .keys () == data_per_streamline .keys ()
226
236
for k in data_dict .keys ():
227
- assert_array_equal (data_dict [k ], data_per_streamline [k ])
237
+ if isinstance (data_dict [k ], np .ndarray ) and np .all (
238
+ data_dict [k ].shape [0 ] == data_dict [k ].shape
239
+ ):
240
+ assert_array_equal (data_dict [k ], data_per_streamline [k ])
228
241
229
242
del data_dict ['mean_curvature' ]
230
243
assert len (data_dict ) == len (data_per_streamline ) - 1
@@ -234,7 +247,10 @@ def test_per_array_dict_creation(self):
234
247
data_dict = PerArrayDict (nb_streamlines , ** data_per_streamline )
235
248
assert data_dict .keys () == data_per_streamline .keys ()
236
249
for k in data_dict .keys ():
237
- assert_array_equal (data_dict [k ], data_per_streamline [k ])
250
+ if isinstance (data_dict [k ], np .ndarray ) and np .all (
251
+ data_dict [k ].shape [0 ] == data_dict [k ].shape
252
+ ):
253
+ assert_array_equal (data_dict [k ], data_per_streamline [k ])
238
254
239
255
del data_dict ['mean_curvature' ]
240
256
assert len (data_dict ) == len (data_per_streamline ) - 1
@@ -261,6 +277,7 @@ def test_extend(self):
261
277
'mean_curvature' : 2 * np .array (DATA ['mean_curvature' ]),
262
278
'mean_torsion' : 3 * np .array (DATA ['mean_torsion' ]),
263
279
'mean_colors' : 4 * np .array (DATA ['mean_colors' ]),
280
+ 'clusters_labels' : 5 * np .array (DATA ['clusters_labels' ], dtype = object ),
264
281
}
265
282
sdict2 = PerArrayDict (len (DATA ['tractogram' ]), new_data )
266
283
@@ -284,7 +301,8 @@ def test_extend(self):
284
301
'mean_curvature' : 2 * np .array (DATA ['mean_curvature' ]),
285
302
'mean_torsion' : 3 * np .array (DATA ['mean_torsion' ]),
286
303
'mean_colors' : 4 * np .array (DATA ['mean_colors' ]),
287
- 'other' : 5 * np .array (DATA ['mean_colors' ]),
304
+ 'clusters_labels' : 5 * np .array (DATA ['clusters_labels' ], dtype = object ),
305
+ 'other' : 6 * np .array (DATA ['mean_colors' ]),
288
306
}
289
307
sdict2 = PerArrayDict (len (DATA ['tractogram' ]), new_data )
290
308
@@ -305,6 +323,7 @@ def test_extend(self):
305
323
'mean_curvature' : 2 * np .array (DATA ['mean_curvature' ]),
306
324
'mean_torsion' : 3 * np .array (DATA ['mean_torsion' ]),
307
325
'mean_colors' : 4 * np .array (DATA ['mean_torsion' ]),
326
+ 'clusters_labels' : 5 * np .array (DATA ['clusters_labels' ], dtype = object ),
308
327
}
309
328
sdict2 = PerArrayDict (len (DATA ['tractogram' ]), new_data )
310
329
with pytest .raises (ValueError ):
@@ -441,7 +460,10 @@ def test_lazydict_creation(self):
441
460
assert is_lazy_dict (data_dict )
442
461
assert data_dict .keys () == expected_keys
443
462
for k in data_dict .keys ():
444
- assert_array_equal (list (data_dict [k ]), list (DATA ['data_per_streamline' ][k ]))
463
+ if isinstance (data_dict [k ], np .ndarray ) and np .all (
464
+ data_dict [k ].shape [0 ] == data_dict [k ].shape
465
+ ):
466
+ assert_array_equal (list (data_dict [k ]), list (DATA ['data_per_streamline' ][k ]))
445
467
446
468
assert len (data_dict ) == len (DATA ['data_per_streamline_func' ])
447
469
@@ -578,6 +600,7 @@ def test_tractogram_add_new_data(self):
578
600
t .data_per_streamline ['mean_curvature' ] = DATA ['mean_curvature' ]
579
601
t .data_per_streamline ['mean_torsion' ] = DATA ['mean_torsion' ]
580
602
t .data_per_streamline ['mean_colors' ] = DATA ['mean_colors' ]
603
+ t .data_per_streamline ['clusters_labels' ] = DATA ['clusters_labels' ]
581
604
assert_tractogram_equal (t , DATA ['tractogram' ])
582
605
583
606
# Retrieve tractogram by their index.
@@ -598,6 +621,7 @@ def test_tractogram_add_new_data(self):
598
621
t .data_per_streamline ['mean_curvature' ] = DATA ['mean_curvature' ]
599
622
t .data_per_streamline ['mean_torsion' ] = DATA ['mean_torsion' ]
600
623
t .data_per_streamline ['mean_colors' ] = DATA ['mean_colors' ]
624
+ t .data_per_streamline ['clusters_labels' ] = DATA ['clusters_labels' ]
601
625
assert_tractogram_equal (t , DATA ['tractogram' ])
602
626
603
627
def test_tractogram_copy (self ):
@@ -647,14 +671,6 @@ def test_creating_invalid_tractogram(self):
647
671
with pytest .raises (ValueError ):
648
672
Tractogram (streamlines = DATA ['streamlines' ], data_per_point = {'scalars' : scalars })
649
673
650
- # Inconsistent dimension for a data_per_streamline.
651
- properties = [[1.11 , 1.22 ], [2.11 ], [3.11 , 3.22 ]]
652
-
653
- with pytest .raises (ValueError ):
654
- Tractogram (
655
- streamlines = DATA ['streamlines' ], data_per_streamline = {'properties' : properties }
656
- )
657
-
658
674
# Too many dimension for a data_per_streamline.
659
675
properties = [
660
676
np .array ([[1.11 ], [1.22 ]], dtype = 'f4' ),
@@ -870,6 +886,7 @@ def test_lazy_tractogram_from_data_func(self):
870
886
DATA ['mean_curvature' ],
871
887
DATA ['mean_torsion' ],
872
888
DATA ['mean_colors' ],
889
+ DATA ['clusters_labels' ],
873
890
]
874
891
875
892
def _data_gen ():
@@ -879,6 +896,7 @@ def _data_gen():
879
896
'mean_curvature' : d [3 ],
880
897
'mean_torsion' : d [4 ],
881
898
'mean_colors' : d [5 ],
899
+ 'clusters_labels' : d [6 ],
882
900
}
883
901
yield TractogramItem (d [0 ], data_for_streamline , data_for_points )
884
902
0 commit comments