Skip to content

Commit fb180e8

Browse files
authored
import normalize_axis_index from numpy.lib on numpy>=2 (#364)
* import `normalize_axis_index` from `numpy.lib` on `numpy>=2` * import the right thing
1 parent 0083ab2 commit fb180e8

File tree

1 file changed

+31
-24
lines changed

1 file changed

+31
-24
lines changed

flox/xrutils.py

+31-24
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import numpy as np
1010
import pandas as pd
11-
from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined]
1211
from packaging.version import Version
1312

1413
try:
@@ -25,6 +24,37 @@
2524
dask_array_type = () # type: ignore[assignment, misc]
2625

2726

27+
def module_available(module: str, minversion: Optional[str] = None) -> bool:
28+
"""Checks whether a module is installed without importing it.
29+
30+
Use this for a lightweight check and lazy imports.
31+
32+
Parameters
33+
----------
34+
module : str
35+
Name of the module.
36+
37+
Returns
38+
-------
39+
available : bool
40+
Whether the module is installed.
41+
"""
42+
has = importlib.util.find_spec(module) is not None
43+
if has:
44+
mod = importlib.import_module(module)
45+
return Version(mod.__version__) >= Version(minversion) if minversion is not None else True
46+
else:
47+
return False
48+
49+
50+
if module_available("numpy", minversion="2.0.0"):
51+
from numpy.lib.array_utils import ( # type: ignore[import-not-found]
52+
normalize_axis_index,
53+
)
54+
else:
55+
from numpy.core.numeric import normalize_axis_index # type: ignore[attr-defined]
56+
57+
2858
def asarray(data, xp=np):
2959
return data if is_duck_array(data) else xp.asarray(data)
3060

@@ -349,26 +379,3 @@ def nanlast(values, axis, keepdims=False):
349379
return np.expand_dims(result, axis=axis)
350380
else:
351381
return result
352-
353-
354-
def module_available(module: str, minversion: Optional[str] = None) -> bool:
355-
"""Checks whether a module is installed without importing it.
356-
357-
Use this for a lightweight check and lazy imports.
358-
359-
Parameters
360-
----------
361-
module : str
362-
Name of the module.
363-
364-
Returns
365-
-------
366-
available : bool
367-
Whether the module is installed.
368-
"""
369-
has = importlib.util.find_spec(module) is not None
370-
if has:
371-
mod = importlib.import_module(module)
372-
return Version(mod.__version__) >= Version(minversion) if minversion is not None else True
373-
else:
374-
return False

0 commit comments

Comments
 (0)