Skip to content

Commit 4a3464f

Browse files
EricCousineau-TRIwjakob
authored andcommitted
numpy: Provide concrete size aliases
Test for dtype checks now succeed without warnings
1 parent e9ca89f commit 4a3464f

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

include/pybind11/numpy.h

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <numeric>
1515
#include <algorithm>
1616
#include <array>
17+
#include <cstdint>
1718
#include <cstdlib>
1819
#include <cstring>
1920
#include <sstream>
@@ -108,6 +109,18 @@ inline numpy_internals& get_numpy_internals() {
108109
return *ptr;
109110
}
110111

112+
template <typename T> struct same_size {
113+
template <typename U> using as = bool_constant<sizeof(T) == sizeof(U)>;
114+
};
115+
116+
// Lookup a type according to its size, and return a value corresponding to the NumPy typenum.
117+
template <typename Concrete, typename... Check, typename... Int>
118+
constexpr int platform_lookup(Int... codes) {
119+
using code_index = std::integral_constant<int, constexpr_first<same_size<Concrete>::template as, Check...>()>;
120+
static_assert(code_index::value != sizeof...(Check), "Unable to match type on this platform");
121+
return std::get<code_index::value>(std::make_tuple(codes...));
122+
}
123+
111124
struct npy_api {
112125
enum constants {
113126
NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
@@ -126,7 +139,23 @@ struct npy_api {
126139
NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
127140
NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
128141
NPY_OBJECT_ = 17,
129-
NPY_STRING_, NPY_UNICODE_, NPY_VOID_
142+
NPY_STRING_, NPY_UNICODE_, NPY_VOID_,
143+
// Platform-dependent normalization
144+
NPY_INT8_ = NPY_BYTE_,
145+
NPY_UINT8_ = NPY_UBYTE_,
146+
NPY_INT16_ = NPY_SHORT_,
147+
NPY_UINT16_ = NPY_USHORT_,
148+
// `npy_common.h` defines the integer aliases. In order, it checks:
149+
// NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
150+
// and assigns the alias to the first matching size, so we should check in this order.
151+
NPY_INT32_ = platform_lookup<std::int32_t, long, int, short>(
152+
NPY_LONG_, NPY_INT_, NPY_SHORT_),
153+
NPY_UINT32_ = platform_lookup<std::uint32_t, unsigned long, unsigned int, unsigned short>(
154+
NPY_ULONG_, NPY_UINT_, NPY_USHORT_),
155+
NPY_INT64_ = platform_lookup<std::int64_t, long, long long, int>(
156+
NPY_LONG_, NPY_LONGLONG_, NPY_INT_),
157+
NPY_UINT64_ = platform_lookup<std::uint64_t, unsigned long, unsigned long long, unsigned int>(
158+
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
130159
};
131160

132161
typedef struct {
@@ -1004,8 +1033,8 @@ struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmet
10041033
// NB: the order here must match the one in common.h
10051034
constexpr static const int values[15] = {
10061035
npy_api::NPY_BOOL_,
1007-
npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
1008-
npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_,
1036+
npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_INT16_, npy_api::NPY_UINT16_,
1037+
npy_api::NPY_INT32_, npy_api::NPY_UINT32_, npy_api::NPY_INT64_, npy_api::NPY_UINT64_,
10091038
npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_,
10101039
npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
10111040
};

0 commit comments

Comments
 (0)