Skip to content

Commit 2adef4b

Browse files
committed
Select submodule name based on numpy version
1 parent 45cc2a0 commit 2adef4b

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

include/pybind11/numpy.h

+21-16
Original file line numberDiff line numberDiff line change
@@ -120,22 +120,24 @@ inline numpy_internals &get_numpy_internals() {
120120
return *ptr;
121121
}
122122

123-
PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char *submodule_name) {
124-
try {
125-
return module_::import((std::string("numpy._core.") + submodule_name).c_str());
126-
} catch (error_already_set &ex) {
127-
if (!ex.matches(PyExc_ImportError)) {
128-
throw;
129-
}
130-
try {
131-
return module_::import((std::string("numpy.core.") + submodule_name).c_str());
132-
} catch (error_already_set &ex) {
133-
if (!ex.matches(PyExc_ImportError)) {
134-
throw;
135-
}
136-
throw import_error(std::string("pybind11 couldn't import ") + submodule_name
137-
+ " from numpy.");
138-
}
123+
PYBIND11_NOINLINE module_ import_numpy_core_submodule(const char * submodule_name) {
124+
module_ numpy = module_::import("numpy");
125+
str version_string = numpy.attr("__version__");
126+
127+
module_ numpy_lib = module_::import("numpy.lib");
128+
object numpy_version = numpy_lib.attr("NumpyVersion")(version_string);
129+
int major_version = numpy_version.attr("major").cast<int>();
130+
131+
/* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
132+
became a private module. */
133+
if (major_version >= 2) {
134+
return py::module_::import(
135+
(std::string("numpy._core.") + submodule_name).c_str()
136+
);
137+
} else {
138+
return py::module_::import(
139+
(std::string("numpy.core.") + submodule_name).c_str()
140+
);
139141
}
140142
}
141143

@@ -285,6 +287,9 @@ struct npy_api {
285287
module_ m = detail::import_numpy_core_submodule("multiarray");
286288
auto c = m.attr("_ARRAY_API");
287289
void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), nullptr);
290+
if (api_ptr == nullptr) {
291+
raise_from(PyExc_SystemError, "FAILURE obtaining numpy _ARRAY_API pointer.");
292+
}
288293
npy_api api;
289294
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
290295
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);

0 commit comments

Comments
 (0)