@@ -120,22 +120,24 @@ inline numpy_internals &get_numpy_internals() {
120
120
return *ptr;
121
121
}
122
122
123
- PYBIND11_NOINLINE module_ import_numpy_core_submodule (const char *submodule_name) {
124
- try {
125
- return module_::import ((std::string (" numpy._core." ) + submodule_name).c_str ());
126
- } catch (error_already_set &ex) {
127
- if (!ex.matches (PyExc_ImportError)) {
128
- throw ;
129
- }
130
- try {
131
- return module_::import ((std::string (" numpy.core." ) + submodule_name).c_str ());
132
- } catch (error_already_set &ex) {
133
- if (!ex.matches (PyExc_ImportError)) {
134
- throw ;
135
- }
136
- throw import_error (std::string (" pybind11 couldn't import " ) + submodule_name
137
- + " from numpy." );
138
- }
123
+ PYBIND11_NOINLINE module_ import_numpy_core_submodule (const char * submodule_name) {
124
+ module_ numpy = module_::import (" numpy" );
125
+ str version_string = numpy.attr (" __version__" );
126
+
127
+ module_ numpy_lib = module_::import (" numpy.lib" );
128
+ object numpy_version = numpy_lib.attr (" NumpyVersion" )(version_string);
129
+ int major_version = numpy_version.attr (" major" ).cast <int >();
130
+
131
+ /* `numpy.core` was renamed to `numpy._core` in NumPy 2.0 as it officially
132
+ became a private module. */
133
+ if (major_version >= 2 ) {
134
+ return py::module_::import (
135
+ (std::string (" numpy._core." ) + submodule_name).c_str ()
136
+ );
137
+ } else {
138
+ return py::module_::import (
139
+ (std::string (" numpy.core." ) + submodule_name).c_str ()
140
+ );
139
141
}
140
142
}
141
143
@@ -285,6 +287,9 @@ struct npy_api {
285
287
module_ m = detail::import_numpy_core_submodule (" multiarray" );
286
288
auto c = m.attr (" _ARRAY_API" );
287
289
void **api_ptr = (void **) PyCapsule_GetPointer (c.ptr (), nullptr );
290
+ if (api_ptr == nullptr ) {
291
+ raise_from (PyExc_SystemError, " FAILURE obtaining numpy _ARRAY_API pointer." );
292
+ }
288
293
npy_api api;
289
294
#define DECL_NPY_API (Func ) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
290
295
DECL_NPY_API (PyArray_GetNDArrayCFeatureVersion);
0 commit comments