Skip to content

Commit f913068

Browse files
authored
unpin xarray, numpy, pandas, netcdf4 (#25)
* unpin xarray, numpy, pandas, netcdf4 * Fix deprecation "TypeError: Using a DataArray object to construct a variable is ambiguous, please extract the data using the .data property." See pydata/xarray#6508. There are other errors, however. * See Unidata/netcdf4-python#1175 "Regression in createVariable between 1.5.8 and 1.6.0". * Testing on Python 3.7 only covers through xarray 0.20.2. This is an experiment. * Fix get_metadata to account for the internal restructuring of indexes in xarray 2022.06. Update tests for same. * Fixed bug where we were stripping off attrs inadvertently when writing netCDF. * TestPlainGroupby.test_on_data_array fails with Python = 3.8.13, numpy = 1.23.2, xarray = 2022.06.0. That means it has nothing to do with BrainIO. * test_on_data_array should not involve any BrainIO classes. This is to test for bugs in xarray. With xarray==2022.06.0, this test fails. * xarray 2022.06.0 has a bug which breaks BrainIO: pydata/xarray#6836. * Adapt get_metadata to the change in the index API between 2022.03.0 and 2022.06.0. Now test_get_metadata passes under 2022.03.0 and 2022.06.0. * Getting an error from tests on Travis (but not locally): RuntimeError: NetCDF: Filter error: bad id or parameters or duplicate filter. This might fix it? * Compression test failed: assert 614732 > 615186. This might fix it. * Travis doesn't offer python 3.10 yet. Make sample assembly bigger so compression has an effect. * Bump minor version. Authored-by: Jonathan Prescott-Roy <[email protected]> and Martin Schrimpf <[email protected]>
1 parent 21f215f commit f913068

File tree

7 files changed

+190
-61
lines changed

7 files changed

+190
-61
lines changed

.travis.yml

+6-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@ matrix:
33
include:
44
- name: 3.7 public
55
python: '3.7'
6-
- name: 3.7 private
7-
python: '3.7'
6+
- name: 3.8 public
7+
python: '3.8'
8+
- name: 3.9 public
9+
python: '3.9'
10+
- name: 3.9 private
11+
python: '3.9'
812
env:
913
- PRIVATE_ACCESS=1
1014
- secure: "CzOQmNMkHXavXihZWYL+G5sbdYq8KLrBWnorZEPhvsKDIKy1hhORCc+pAMXg+bjrPRXfRqZnX0XRRCoZbD9Mo9VvA1hIsV7i5bBbjMoyBTUn3vED0CQNBCgjaA2rLsHlJMtYdLoCOOAiaU+rTu2xxf0grjgKARzLpVNENmPgP0YqiXPEc7rdY3cifalCBpHTQgvu7Z6FR1yAdRsMfskTIwPa/GlTCNF8ZR+efuobQJrtApfzBgiH7+NJI5Aq6u8PWD6LqONCm2ut0NKL7BMNRMgwS3pjERr2spRWrLiCz05Y4icaUmhajPjCl3kMIjuHdw1OgvwQHuSW9hcgt0AXZoIC8qJqg5V39LrsYYPd5/sg7vcTZ+VRhWF5zDBMvTO0PFt36tpj9xnr2ATIPlp1ACXwi+fGPkPAJp3ZIHbl36lji6sB4WLwIISongseizqTAHKowmpCGqEL6TZB65/MThWBeccRNB1N4a3wG34Eu7n1XXqecK1c+68JO98fOQxwmQ/utOkQRcVQzmGyARUk7WyupoqMmAZbWxOJ5AzyXPiK2OGXmiVJSwlMQKtF7eqkLs8wWeQD+zQj2qoSqF45LdFQsww19W2wC0wHuTV6nDBaKB59lY5qFufDWT+Gh06jLk8UpgYANh9f3fH5ZgUKfnH7I17StuDEpxCZ1kxVKcA="

brainio/assemblies.py

+47-3
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,8 @@ def array_is_element(arr, element):
328328
return len(arr) == 1 and arr[0] == element
329329

330330

331-
def get_metadata(assembly, dims=None, names_only=False, include_coords=True,
332-
include_indexes=True, include_multi_indexes=False, include_levels=True):
331+
def get_metadata_before_2022_06(assembly, dims=None, names_only=False, include_coords=True,
332+
include_indexes=True, include_multi_indexes=False, include_levels=True):
333333
"""
334334
Return coords and/or indexes or index levels from an assembly, yielding either `name` or `(name, dims, values)`.
335335
"""
@@ -362,6 +362,50 @@ def what(name, dims, values, names_only):
362362
yield what(name, values.dims, values.values, names_only)
363363

364364

365+
def get_metadata_after_2022_06(assembly, dims=None, names_only=False, include_coords=True,
366+
include_indexes=True, include_multi_indexes=False, include_levels=True):
367+
"""
368+
Return coords and/or indexes or index levels from an assembly, yielding either `name` or `(name, dims, values)`.
369+
"""
370+
def what(name, dims, values, names_only):
371+
if names_only:
372+
return name
373+
else:
374+
return name, dims, values
375+
if dims is None:
376+
dims = assembly.dims + (None,) # all dims plus dimensionless coords
377+
for name, values in assembly.coords.items():
378+
none_but_keep = (not values.dims) and None in dims
379+
shared = not (set(values.dims).isdisjoint(set(dims)))
380+
if none_but_keep or shared:
381+
if name in assembly.indexes: # it's an index
382+
index = assembly.indexes[name]
383+
if len(index.names) > 1: # it's a MultiIndex or level
384+
if name in index.names: # it's a level
385+
if include_levels:
386+
yield what(name, values.dims, values.values, names_only)
387+
else: # it's a MultiIndex
388+
if include_multi_indexes:
389+
yield what(name, values.dims, values.values, names_only)
390+
else: # it's a single Index
391+
if include_indexes:
392+
yield what(name, values.dims, values.values, names_only)
393+
else: # it's a coord
394+
if include_coords:
395+
yield what(name, values.dims, values.values, names_only)
396+
397+
398+
def get_metadata(assembly, dims=None, names_only=False, include_coords=True,
399+
include_indexes=True, include_multi_indexes=False, include_levels=True):
400+
try:
401+
xr.DataArray().stack(create_index=True)
402+
yield from get_metadata_after_2022_06(assembly, dims, names_only, include_coords,
403+
include_indexes, include_multi_indexes, include_levels)
404+
except TypeError as e:
405+
yield from get_metadata_before_2022_06(assembly, dims, names_only, include_coords,
406+
include_indexes, include_multi_indexes, include_levels)
407+
408+
365409
def coords_for_dim(assembly, dim):
366410
result = OrderedDict()
367411
meta = get_metadata(assembly, dims=(dim,), include_indexes=False, include_levels=False)
@@ -415,7 +459,7 @@ def correct_stimulus_id_name(cls, assembly):
415459
names = list(get_metadata(assembly, dims=('presentation',), names_only=True))
416460
if 'image_id' in names and 'stimulus_id' not in names:
417461
assembly = assembly.assign_coords(
418-
stimulus_id=('presentation', assembly['image_id']),
462+
stimulus_id=('presentation', assembly['image_id'].data),
419463
)
420464
return assembly
421465

brainio/packaging.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,9 @@ def write_netcdf(assembly, target_netcdf_file, append=False, group=None, compres
208208
mode = "a" if append else "w"
209209
target_netcdf_file.parent.mkdir(parents=True, exist_ok=True)
210210
if compress:
211-
ds = assembly.to_dataset(name="data")
212-
compression = dict(zlib=True, complevel=1)
213-
encoding = {var: compression for var in ds.variables}
211+
ds = assembly.to_dataset(name="data", promote_attrs=True)
212+
compression = dict(zlib=True, complevel=9)
213+
encoding = {var: compression for var in ds.data_vars}
214214
ds.to_netcdf(target_netcdf_file, mode=mode, group=group, encoding=encoding)
215215
else:
216216
assembly.to_netcdf(target_netcdf_file, mode=mode, group=group)

setup.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
"tqdm",
1313
"Pillow",
1414
"entrypoints",
15-
"numpy>=1.16.5, !=1.21.*",
16-
"pandas>=1.2.0, !=1.3.0",
17-
"xarray==0.17.0",
18-
"netcdf4==1.5.8",
15+
"numpy",
16+
"pandas",
17+
"xarray!=2022.06.0", # 2022.06.0 has a bug which breaks BrainIO: https://github.com/pydata/xarray/issues/6836
18+
"netcdf4!=1.6.0", # https://github.com/Unidata/netcdf4-python/issues/1175,
1919
]
2020

2121
setup(
2222
name='brainio',
23-
version='0.1.0',
23+
version='0.2.0',
2424
description="Data management for quantitative comparison of brains and brain-inspired systems",
2525
long_description=readme,
2626
author="Jon Prescott-Roy, Martin Schrimpf",

tests/conftest.py

+19-20
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,8 @@ def make_proto_assembly():
5555

5656

5757
def scattered_floats(lo, hi, num):
58-
# a kludge: looks stochastic, but deterministic
59-
mid = (hi + lo) / 2
60-
half = mid - lo
61-
jump = 8
62-
return [mid + np.sin(x) * half for x in range(2, num * (jump + 1), jump)][:num]
58+
rng = np.random.default_rng(12345)
59+
return rng.random(num) * (hi - lo) + lo
6360

6461

6562
# taken from values in /braintree/data2/active/users/sachis/projects/oasis900/monkeys/oleo/mworksproc/oleo_oasis900_210216_113846_mwk.csv
@@ -90,23 +87,25 @@ def make_meta_assembly():
9087
return a
9188

9289

93-
def make_spk_assembly():
90+
def make_spk_assembly(magnitude=3):
91+
size = 10**magnitude
92+
half = int((10**magnitude) / 2)
9493
coords = {
95-
"neuroid_id": ("event", ["A-019", "D-009"]*500),
96-
"project": ("event", ["test"]*1000),
97-
"datetime": ("event", np.repeat(np.datetime64('2021-02-16T11:41:55.000000000'), 1000)),
98-
"animal": ("event", ["testo"]*1000),
99-
"hemisphere": ("event", ["L", "R"]*500),
100-
"region": ("event", ["V4", "IT"]*500),
101-
"subregion": ("event", ["V4", "aIT"]*500),
102-
"array": ("event", ["6250-002416", "4865-233455"]*500),
103-
"bank": ("event", ["A", "D"]*500),
104-
"electrode": ("event", ["019", "009"]*500),
105-
"column": ("event", [5, 2]*500),
106-
"row": ("event", [4, 8]*500),
107-
"label": ("event", ["elec46", "elec123"]*500),
94+
"neuroid_id": ("event", ["A-019", "D-009"]*half),
95+
"project": ("event", ["test"]*size),
96+
"datetime": ("event", np.repeat(np.datetime64('2021-02-16T11:41:55.000000000'), size)),
97+
"animal": ("event", ["testo"]*size),
98+
"hemisphere": ("event", ["L", "R"]*half),
99+
"region": ("event", ["V4", "IT"]*half),
100+
"subregion": ("event", ["V4", "aIT"]*half),
101+
"array": ("event", ["6250-002416", "4865-233455"]*half),
102+
"bank": ("event", ["A", "D"]*half),
103+
"electrode": ("event", ["019", "009"]*half),
104+
"column": ("event", [5, 2]*half),
105+
"row": ("event", [4, 8]*half),
106+
"label": ("event", ["elec46", "elec123"]*half),
108107
}
109-
data = sorted(scattered_floats(67.7, 21116.2, 1000))
108+
data = sorted(scattered_floats(67.7, 21116.2, size))
110109
a = SpikeTimesAssembly(
111110
data=data,
112111
coords=coords,

tests/test_assemblies.py

+108-26
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,48 @@
1515
SpikeTimesAssembly, get_metadata
1616

1717

18+
def test_get_metadata():
19+
xr.show_versions()
20+
# assembly, dims, names_only, include_coords, include_indexes, include_multi_indexes, include_levels
21+
assy = DataAssembly(
22+
data=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18]],
23+
coords={
24+
'up': ("a", ['alpha', 'alpha', 'beta', 'beta', 'beta', 'beta']),
25+
'down': ("a", [1, 1, 1, 1, 2, 2]),
26+
'why': ("a", ['yes', 'yes', 'yes', 'yes', 'yes', 'no']),
27+
'b': ('b', ['x', 'y', 'z']),
28+
},
29+
dims=['a', 'b']
30+
)
31+
assy = assy.reset_index('why')
32+
assert set(get_metadata(assy, None, True, True, True, True, True)) == {'a', 'up', 'down', 'why', 'b'}
33+
assert set(get_metadata(assy, None, True, True, True, True, False)) == {'a', 'why', 'b'}
34+
assert set(get_metadata(assy, None, True, True, True, False, True)) == {'up', 'down', 'why', 'b'}
35+
assert set(get_metadata(assy, None, True, True, True, False, False)) == {'why', 'b'}
36+
assert set(get_metadata(assy, None, True, True, False, True, True)) == {'a', 'up', 'down', 'why'}
37+
assert set(get_metadata(assy, None, True, True, False, True, False)) == {'a', 'why'}
38+
assert set(get_metadata(assy, None, True, True, False, False, True)) == {'up', 'down', 'why'}
39+
assert set(get_metadata(assy, None, True, True, False, False, False)) == {'why'}
40+
assert set(get_metadata(assy, None, True, False, True, True, True)) == {'a', 'up', 'down', 'b'}
41+
assert set(get_metadata(assy, None, True, False, True, True, False)) == {'a', 'b'}
42+
assert set(get_metadata(assy, None, True, False, True, False, True)) == {'up', 'down', 'b'}
43+
assert set(get_metadata(assy, None, True, False, True, False, False)) == {'b'}
44+
assert set(get_metadata(assy, None, True, False, False, True, True)) == {'a', 'up', 'down'}
45+
assert set(get_metadata(assy, None, True, False, False, True, False)) == {'a'}
46+
assert set(get_metadata(assy, None, True, False, False, False, True)) == {'up', 'down'}
47+
assert set(get_metadata(assy, None, True, False, False, False, False)) == set()
48+
49+
a = make_proto_assembly()
50+
md_all = list(get_metadata(a))
51+
assert len(md_all) == 4
52+
md_coo = list(get_metadata(a, include_indexes=False, include_levels=False))
53+
assert len(md_coo) == 0
54+
md_ind = list(get_metadata(a, include_coords=False, include_indexes=True, include_multi_indexes=True, include_levels=False))
55+
assert len(md_ind) == 2
56+
md_lev = list(get_metadata(a, include_coords=False, include_indexes=False))
57+
assert len(md_lev) == 4
58+
59+
1860
def test_get_levels():
1961
assy = DataAssembly(
2062
data=[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18]],
@@ -40,12 +82,11 @@ def test_wrap_dataarray(self):
4082
dims=['a', 'b']
4183
)
4284
assert "up" in da.coords
43-
assert da["a"].variable.level_names is None
85+
assert "a" not in da.indexes
4486
da = gather_indexes(da)
45-
assert da.coords.variables["a"].level_names == ["up", "down"]
46-
assert da["a"].variable.level_names == ["up", "down"]
87+
assert da.indexes["a"].names == ["up", "down"]
4788
da = DataArray(da)
48-
assert da.coords.variables["a"].level_names == ["up", "down"]
89+
assert da.indexes["a"].names == ["up", "down"]
4990
assert da["up"] is not None
5091

5192
def test_wrap_dataassembly(self):
@@ -58,11 +99,9 @@ def test_wrap_dataassembly(self):
5899
},
59100
dims=['a', 'b']
60101
)
61-
assert assy.coords.variables["a"].level_names == ["up", "down"]
62-
assert assy["a"].variable.level_names == ["up", "down"]
102+
assert assy.indexes["a"].names == ["up", "down"]
63103
da = DataArray(assy)
64-
assert da.coords.variables["a"].level_names == ["up", "down"]
65-
assert da["a"].variable.level_names == ["up", "down"]
104+
assert assy.indexes["a"].names == ["up", "down"]
66105
assert da["up"] is not None
67106

68107
def test_reset_index(self):
@@ -109,6 +148,7 @@ def test_getitem(self):
109148
)
110149
single = assy[0, 0]
111150
assert type(single) is type(assy)
151+
assert single == 1
112152

113153
def test_is_fastpath(self):
114154
"""In DataAssembly.__init__ we have to check whether fastpath is present in a set of arguments and true
@@ -164,17 +204,14 @@ def test_align(self):
164204
dims=['a', 'b']
165205
)
166206
assert hasattr(da1, "up")
167-
assert da1.coords.variables["a"].level_names == ["up", "down"]
168-
assert da1["a"].variable.level_names == ["up", "down"]
207+
assert da1.indexes["a"].names == ["up", "down"]
169208
assert da1["up"] is not None
170209
aligned1, aligned2 = xr.align(da1, da2, join="outer")
171210
assert hasattr(aligned1, "up")
172-
assert aligned1.coords.variables["a"].level_names == ["up", "down"]
173-
assert aligned1["a"].variable.level_names == ["up", "down"]
211+
assert aligned1.indexes["a"].names == ["up", "down"]
174212
assert aligned1["up"] is not None
175213
assert hasattr(aligned2, "up")
176-
assert aligned2.coords.variables["a"].level_names == ["up", "down"]
177-
assert aligned2["a"].variable.level_names == ["up", "down"]
214+
assert aligned2.indexes["a"].names == ["up", "down"]
178215
assert aligned2["up"] is not None
179216

180217

@@ -202,6 +239,60 @@ def test_incorrect_coord(self):
202239
d.sel(coordB=0)
203240

204241

242+
class TestPlainGroupy:
243+
244+
def test_on_data_array(self):
245+
d = DataArray(
246+
data=[
247+
[0, 1, 2, 3, 4, 5, 6],
248+
[7, 8, 9, 10, 11, 12, 13],
249+
[14, 15, 16, 17, 18, 19, 20]
250+
],
251+
coords={
252+
"greek": ("a", ['alpha', 'beta', 'gamma']),
253+
"colors": ("a", ['red', 'green', 'blue']),
254+
"compass": ("b", ['north', 'south', 'east', 'west', 'northeast', 'southeast', 'southwest']),
255+
"integer": ("b", [0, 1, 2, 3, 4, 5, 6]),
256+
},
257+
dims=("a", "b")
258+
)
259+
d = gather_indexes(d)
260+
g = d.groupby('greek')
261+
# with xarray==2022.06.0, the following line fails with:
262+
# ValueError: conflicting multi-index level name 'greek' with dimension 'greek'
263+
m = g.mean(...)
264+
c = DataArray(
265+
data=[3, 10, 17],
266+
coords={'greek': ('greek', ['alpha', 'beta', 'gamma'])},
267+
dims=['greek']
268+
)
269+
assert m.equals(c)
270+
271+
def test_on_data_assembly(self):
272+
d = DataAssembly(
273+
data=[
274+
[0, 1, 2, 3, 4, 5, 6],
275+
[7, 8, 9, 10, 11, 12, 13],
276+
[14, 15, 16, 17, 18, 19, 20]
277+
],
278+
coords={
279+
"greek": ("a", ['alpha', 'beta', 'gamma']),
280+
"colors": ("a", ['red', 'green', 'blue']),
281+
"compass": ("b", ['north', 'south', 'east', 'west', 'northeast', 'southeast', 'southwest']),
282+
"integer": ("b", [0, 1, 2, 3, 4, 5, 6]),
283+
},
284+
dims=("a", "b")
285+
)
286+
g = d.groupby('greek')
287+
m = g.mean(...)
288+
c = DataAssembly(
289+
data=[3, 10, 17],
290+
coords={'greek': ('greek', ['alpha', 'beta', 'gamma'])},
291+
dims=['greek']
292+
)
293+
assert m.equals(c)
294+
295+
205296
class TestMultiGroupby:
206297
def test_single_dimension(self):
207298
d = DataAssembly([[1, 2, 3], [4, 5, 6]], coords={'a': ['a', 'b'], 'b': ['x', 'y', 'z']}, dims=['a', 'b'])
@@ -228,13 +319,14 @@ def test_single_coord(self):
228319
},
229320
dims=("a", "b")
230321
)
231-
g = d.multi_groupby(['greek']).mean(...)
322+
g = d.multi_groupby(['greek'])
323+
m = g.mean(...)
232324
c = DataAssembly(
233325
data=[3, 10, 17],
234326
coords={'greek': ('greek', ['alpha', 'beta', 'gamma'])},
235327
dims=['greek']
236328
)
237-
assert g.equals(c)
329+
assert m.equals(c)
238330

239331
def test_single_dim_multi_coord(self):
240332
d = DataAssembly([1, 2, 3, 4, 5, 6],
@@ -452,15 +544,5 @@ def test_load_extras(self, test_stimulus_set_identifier):
452544
assert extra.shape == (40,)
453545

454546

455-
def test_get_metadata():
456-
a = make_proto_assembly()
457-
md_all = list(get_metadata(a))
458-
assert len(md_all) == 4
459-
md_coo = list(get_metadata(a, include_indexes=False, include_levels=False))
460-
assert len(md_coo) == 0
461-
md_ind = list(get_metadata(a, include_coords=False, include_indexes=True, include_multi_indexes=True, include_levels=False))
462-
assert len(md_ind) == 2
463-
md_lev = list(get_metadata(a, include_coords=False, include_indexes=False))
464-
assert len(md_lev) == 4
465547

466548

tests/test_packaging.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,9 @@ def test_package_extras(test_stimulus_set_identifier, test_catalog_identifier, b
176176

177177

178178
def test_compression(test_write_netcdf_path):
179-
write_netcdf(make_spk_assembly(), test_write_netcdf_path, compress=False)
179+
write_netcdf(make_spk_assembly(6), test_write_netcdf_path, compress=False)
180180
uncompressed = test_write_netcdf_path.stat().st_size
181-
write_netcdf(make_spk_assembly(), test_write_netcdf_path, compress=True)
181+
write_netcdf(make_spk_assembly(6), test_write_netcdf_path, compress=True)
182182
compressed = test_write_netcdf_path.stat().st_size
183183
assert uncompressed > compressed
184184

0 commit comments

Comments
 (0)