|
14 | 14 |
|
15 | 15 | #include <cstdint>
|
16 | 16 |
|
| 17 | +// Size / dtype checks. |
| 18 | +struct DtypeCheck { |
| 19 | + py::dtype numpy{}; |
| 20 | + py::dtype pybind11{}; |
| 21 | +}; |
| 22 | + |
| 23 | +template <typename T> |
| 24 | +DtypeCheck get_dtype_check(const char* name) { |
| 25 | + py::module np = py::module::import("numpy"); |
| 26 | + DtypeCheck check{}; |
| 27 | + check.numpy = np.attr("dtype")(np.attr(name)); |
| 28 | + check.pybind11 = py::dtype::of<T>(); |
| 29 | + return check; |
| 30 | +} |
| 31 | + |
| 32 | +std::vector<DtypeCheck> get_concrete_dtype_checks() { |
| 33 | + return { |
| 34 | + // Normalization |
| 35 | + get_dtype_check<std::int8_t>("int8"), |
| 36 | + get_dtype_check<std::uint8_t>("uint8"), |
| 37 | + get_dtype_check<std::int16_t>("int16"), |
| 38 | + get_dtype_check<std::uint16_t>("uint16"), |
| 39 | + get_dtype_check<std::int32_t>("int32"), |
| 40 | + get_dtype_check<std::uint32_t>("uint32"), |
| 41 | + get_dtype_check<std::int64_t>("int64"), |
| 42 | + get_dtype_check<std::uint64_t>("uint64") |
| 43 | + }; |
| 44 | +} |
| 45 | + |
| 46 | +struct DtypeSizeCheck { |
| 47 | + std::string name{}; |
| 48 | + int size_cpp{}; |
| 49 | + int size_numpy{}; |
| 50 | + // For debugging. |
| 51 | + py::dtype dtype{}; |
| 52 | +}; |
| 53 | + |
| 54 | +template <typename T> |
| 55 | +DtypeSizeCheck get_dtype_size_check() { |
| 56 | + DtypeSizeCheck check{}; |
| 57 | + check.name = py::type_id<T>(); |
| 58 | + check.size_cpp = sizeof(T); |
| 59 | + check.dtype = py::dtype::of<T>(); |
| 60 | + check.size_numpy = check.dtype.attr("itemsize").template cast<int>(); |
| 61 | + return check; |
| 62 | +} |
| 63 | + |
| 64 | +std::vector<DtypeSizeCheck> get_platform_dtype_size_checks() { |
| 65 | + return { |
| 66 | + get_dtype_size_check<short>(), |
| 67 | + get_dtype_size_check<unsigned short>(), |
| 68 | + get_dtype_size_check<int>(), |
| 69 | + get_dtype_size_check<unsigned int>(), |
| 70 | + get_dtype_size_check<long>(), |
| 71 | + get_dtype_size_check<unsigned long>(), |
| 72 | + get_dtype_size_check<long long>(), |
| 73 | + get_dtype_size_check<unsigned long long>(), |
| 74 | + }; |
| 75 | +} |
| 76 | + |
| 77 | +// Arrays. |
17 | 78 | using arr = py::array;
|
18 | 79 | using arr_t = py::array_t<uint16_t, 0>;
|
19 | 80 | static_assert(std::is_same<arr_t::value_type, uint16_t>::value, "");
|
@@ -75,6 +136,26 @@ TEST_SUBMODULE(numpy_array, sm) {
|
75 | 136 | try { py::module::import("numpy"); }
|
76 | 137 | catch (...) { return; }
|
77 | 138 |
|
| 139 | + // test_dtypes |
| 140 | + py::class_<DtypeCheck>(sm, "DtypeCheck") |
| 141 | + .def_readonly("numpy", &DtypeCheck::numpy) |
| 142 | + .def_readonly("pybind11", &DtypeCheck::pybind11) |
| 143 | + .def("__repr__", [](const DtypeCheck& self) { |
| 144 | + return py::str("<DtypeCheck numpy={} pybind11={}>").format( |
| 145 | + self.numpy, self.pybind11); |
| 146 | + }); |
| 147 | + sm.def("get_concrete_dtype_checks", &get_concrete_dtype_checks); |
| 148 | + |
| 149 | + py::class_<DtypeSizeCheck>(sm, "DtypeSizeCheck") |
| 150 | + .def_readonly("name", &DtypeSizeCheck::name) |
| 151 | + .def_readonly("size_cpp", &DtypeSizeCheck::size_cpp) |
| 152 | + .def_readonly("size_numpy", &DtypeSizeCheck::size_numpy) |
| 153 | + .def("__repr__", [](const DtypeSizeCheck& self) { |
| 154 | + return py::str("<DtypeSizeCheck name='{}' size_cpp={} size_numpy={} dtype={}>").format( |
| 155 | + self.name, self.size_cpp, self.size_numpy, self.dtype); |
| 156 | + }); |
| 157 | + sm.def("get_platform_dtype_size_checks", &get_platform_dtype_size_checks); |
| 158 | + |
78 | 159 | // test_array_attributes
|
79 | 160 | sm.def("ndim", [](const arr& a) { return a.ndim(); });
|
80 | 161 | sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); });
|
|
0 commit comments