Skip to content

Commit 4ca1173

Browse files
committed
Generic containers/iterators for shape/strides
This adds support for constructing `buffer_info` and `array`s using arbitrary containers or iterators instead of requiring a vector. On its own, this adds quite a few templated constructors which isn't all that desirable, but it's needed by PR pybind#782 to preserve backwards compatibility for previous pybind11 versions that accept a `std::vector<size_t>` for strides. Given a choice between duplicating all the stride-taking constructors (accepting both a ssize_t and size_t vector) and making the whole interface more general, I think this is a better approach. I also seem to recall some discussion a few months ago about wanting something like this, but I can't find the issue/PR where it was mentioned.
1 parent b42b1b0 commit 4ca1173

File tree

5 files changed

+114
-54
lines changed

5 files changed

+114
-54
lines changed

include/pybind11/buffer_info.h

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,37 @@ struct buffer_info {
2525

2626
buffer_info() { }
2727

28+
template <typename ShapeIt, typename StridesIt,
29+
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
2830
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t ndim,
29-
const std::vector<size_t> &shape, const std::vector<size_t> &strides)
30-
: ptr(ptr), itemsize(itemsize), size(1), format(format),
31-
ndim(ndim), shape(shape), strides(strides) {
31+
ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last)
32+
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
33+
shape(shape_first, shape_last), strides(strides_first, strides_last) {
34+
if (ndim != shape.size() || ndim != strides.size())
35+
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
3236
for (size_t i = 0; i < ndim; ++i)
3337
size *= shape[i];
3438
}
3539

40+
template <typename Shape, typename Strides,
41+
typename = decltype(std::begin(std::declval<const Shape &>())),
42+
typename = decltype(std::begin(std::declval<const Shape &>()))>
43+
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t ndim,
44+
const Shape &shape, const Strides &strides)
45+
: buffer_info(ptr, itemsize, format, ndim, std::begin(shape), std::end(shape), std::begin(strides), std::end(strides)) { }
46+
47+
template <typename T1, typename T2>
48+
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t ndim,
49+
const std::initializer_list<T1> &shape, const std::initializer_list<T2> &strides)
50+
: buffer_info(ptr, itemsize, format, ndim, shape.begin(), shape.end(), strides.begin(), strides.end()) { }
51+
3652
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t size)
37-
: buffer_info(ptr, itemsize, format, 1, std::vector<size_t> { size },
38-
std::vector<size_t> { itemsize }) { }
39-
40-
explicit buffer_info(Py_buffer *view, bool ownview = true)
41-
: ptr(view->buf), itemsize((size_t) view->itemsize), size(1), format(view->format),
42-
ndim((size_t) view->ndim), shape((size_t) view->ndim), strides((size_t) view->ndim), view(view), ownview(ownview) {
43-
for (size_t i = 0; i < (size_t) view->ndim; ++i) {
44-
shape[i] = (size_t) view->shape[i];
45-
strides[i] = (size_t) view->strides[i];
46-
size *= shape[i];
47-
}
53+
: buffer_info(ptr, itemsize, format, 1, { size }, { itemsize }) { }
54+
55+
explicit buffer_info(Py_buffer *view, bool ownview_in = true)
56+
: buffer_info(view->buf, (size_t) view->itemsize, view->format, (size_t) view->ndim,
57+
view->shape, view->shape + view->ndim, view->strides, view->strides + view->ndim) {
58+
ownview = ownview_in;
4859
}
4960

5061
buffer_info(const buffer_info &) = delete;

include/pybind11/common.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,12 @@ struct is_instantiation<Class, Class<Us...>> : std::true_type { };
490490
/// Check if T is std::shared_ptr<U> where U can be anything
491491
template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
492492

493+
/// Check if T looks like an input iterator
494+
template <typename T, typename = void> struct is_input_iterator : std::false_type {};
495+
template <typename T>
496+
struct is_input_iterator<T, void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
497+
: std::true_type {};
498+
493499
/// Ignore that a variable is unused in compiler warnings
494500
inline void ignore_unused(const int *) { }
495501

include/pybind11/eigen.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -200,18 +200,13 @@ template <typename Type_> struct EigenProps {
200200
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
201201
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
202202
constexpr size_t elem_size = sizeof(typename props::Scalar);
203-
std::vector<size_t> shape, strides;
204-
if (props::vector) {
205-
shape.push_back(src.size());
206-
strides.push_back(elem_size * src.innerStride());
207-
}
208-
else {
209-
shape.push_back(src.rows());
210-
shape.push_back(src.cols());
211-
strides.push_back(elem_size * src.rowStride());
212-
strides.push_back(elem_size * src.colStride());
213-
}
214-
array a(std::move(shape), std::move(strides), src.data(), base);
203+
array a;
204+
if (props::vector)
205+
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
206+
else
207+
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
208+
src.data(), base);
209+
215210
if (!writeable)
216211
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
217212

include/pybind11/numpy.h

Lines changed: 74 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -455,12 +455,15 @@ class array : public buffer {
455455

456456
array() : array(0, static_cast<const double *>(nullptr)) {}
457457

458-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
459-
const std::vector<size_t> &strides, const void *ptr = nullptr,
460-
handle base = handle()) {
458+
template <typename ShapeIt, typename StridesIt,
459+
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
460+
array(const pybind11::dtype &dt, ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
461+
const void *ptr = nullptr, handle base = handle()) {
461462
auto& api = detail::npy_api::get();
463+
464+
std::vector<Py_intptr_t> shape(shape_first, shape_last), strides(strides_first, strides_last);
462465
auto ndim = shape.size();
463-
if (shape.size() != strides.size())
466+
if (ndim != strides.size())
464467
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
465468
auto descr = dt;
466469

@@ -475,9 +478,7 @@ class array : public buffer {
475478
}
476479

477480
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
478-
api.PyArray_Type_, descr.release().ptr(), (int) ndim,
479-
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())),
480-
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())),
481+
api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape.data(), strides.data(),
481482
const_cast<void *>(ptr), flags, nullptr));
482483
if (!tmp)
483484
pybind11_fail("NumPy: unable to create array!");
@@ -491,27 +492,59 @@ class array : public buffer {
491492
m_ptr = tmp.release().ptr();
492493
}
493494

494-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
495+
template <typename Shape, typename Strides,
496+
typename = decltype(std::begin(std::declval<const Shape &>())),
497+
typename = decltype(std::begin(std::declval<const Strides &>()))>
498+
array(const pybind11::dtype &dt, const Shape &shape, const Strides &strides,
495499
const void *ptr = nullptr, handle base = handle())
496-
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
500+
: array(dt, std::begin(shape), std::end(shape), std::begin(strides), std::end(strides), ptr, base) { }
497501

498-
array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
499-
handle base = handle())
500-
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
502+
template <typename TSh, typename TSt>
503+
array(const pybind11::dtype &dt, const std::initializer_list<TSh> &shape, const std::initializer_list<TSt> &strides,
504+
const void *ptr = nullptr, handle base = handle())
505+
: array(dt, std::begin(shape), std::end(shape), std::begin(strides), std::end(strides), ptr, base) { }
501506

502-
template<typename T> array(const std::vector<size_t>& shape,
503-
const std::vector<size_t>& strides,
504-
const T* ptr, handle base = handle())
505-
: array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
507+
template <typename Shape, typename = decltype(std::begin(std::declval<const Shape &>()))>
508+
array(const pybind11::dtype &dt, const Shape &shape, const void *ptr = nullptr, handle base = handle())
509+
: array(dt, shape, default_strides({std::begin(shape), std::end(shape)}, dt.itemsize()), ptr, base) { }
506510

507-
template <typename T>
508-
array(const std::vector<size_t> &shape, const T *ptr,
511+
template <typename TSh>
512+
array(const pybind11::dtype &dt, const std::initializer_list<TSh> &shape, const void *ptr = nullptr, handle base = handle())
513+
: array(dt, shape, default_strides({std::begin(shape), std::end(shape)}, dt.itemsize()), ptr, base) { }
514+
515+
array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
509516
handle base = handle())
510-
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
517+
: array(dt, { count }, ptr, base) { }
518+
519+
template <typename T, typename ShapeIt, typename StridesIt,
520+
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
521+
array(ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
522+
const T *ptr = nullptr, handle base = handle())
523+
: array(pybind11::dtype::of<T>(), std::move(shape_first), std::move(shape_last),
524+
std::move(strides_first), std::move(strides_last), ptr, base) { }
525+
526+
template <typename T, typename Shape, typename Strides,
527+
typename = decltype(std::begin(std::declval<const Shape &>())),
528+
typename = decltype(std::begin(std::declval<const Strides &>()))>
529+
array(const Shape &shape, const Strides &strides, const T *ptr, handle base = handle())
530+
: array(pybind11::dtype::of<T>(), shape, strides, ptr, base) { }
531+
532+
template <typename T, typename TSh, typename TSt>
533+
array(const std::initializer_list<TSh> &shape, const std::initializer_list<TSt> &strides,
534+
const T *ptr, handle base = handle())
535+
: array(pybind11::dtype::of<T>(), shape, strides, ptr, base) { }
536+
537+
template <typename T, typename Shape, typename = decltype(std::begin(std::declval<const Shape &>()))>
538+
array(const Shape &shape, const T *ptr, handle base = handle())
539+
: array(shape, default_strides({std::begin(shape), std::end(shape)}, sizeof(T)), ptr, base) { }
540+
541+
template <typename T, typename TSh>
542+
array(const std::initializer_list<TSh> &shape, const T *ptr, handle base = handle())
543+
: array(shape, default_strides({std::begin(shape), std::end(shape)}, sizeof(T)), ptr, base) { }
511544

512545
template <typename T>
513546
array(size_t count, const T *ptr, handle base = handle())
514-
: array(std::vector<size_t>{ count }, ptr, base) { }
547+
: array({ count }, ptr, base) { }
515548

516549
explicit array(const buffer_info &info)
517550
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
@@ -729,13 +762,29 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
729762

730763
explicit array_t(const buffer_info& info) : array(info) { }
731764

732-
array_t(const std::vector<size_t> &shape,
733-
const std::vector<size_t> &strides, const T *ptr = nullptr,
734-
handle base = handle())
765+
template <typename ShapeIt, typename StridesIt,
766+
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
767+
array_t(ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
768+
const T *ptr = nullptr, handle base = handle())
769+
: array(std::move(shape_first), std::move(shape_last), std::move(strides_first), std::move(strides_last), ptr, base) { }
770+
771+
template <typename Shape, typename Strides,
772+
typename = decltype(std::begin(std::declval<const Shape &>())),
773+
typename = decltype(std::begin(std::declval<const Strides &>()))>
774+
array_t(const Shape &shape, const Strides &strides, const T *ptr = nullptr, handle base = handle())
775+
: array(shape, strides, ptr, base) { }
776+
777+
template <typename TSh, typename TSt>
778+
array_t(const std::initializer_list<TSh> &shape, const std::initializer_list<TSt> &strides,
779+
const T *ptr = nullptr, handle base = handle())
735780
: array(shape, strides, ptr, base) { }
736781

737-
explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
738-
handle base = handle())
782+
template <typename Shape, typename = decltype(std::begin(std::declval<const Shape &>()))>
783+
explicit array_t(const Shape &shape, const T *ptr = nullptr, handle base = handle())
784+
: array(shape, ptr, base) { }
785+
786+
template <typename TSh>
787+
explicit array_t(const std::initializer_list<TSh> &shape, const T *ptr = nullptr, handle base = handle())
739788
: array(shape, ptr, base) { }
740789

741790
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())

tests/test_numpy_array.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include <pybind11/stl.h>
1414

1515
#include <cstdint>
16-
#include <vector>
1716

1817
using arr = py::array;
1918
using arr_t = py::array_t<uint16_t, 0>;
@@ -119,8 +118,8 @@ test_initializer numpy_array([](py::module &m) {
119118
sm.def("wrap", [](py::array a) {
120119
return py::array(
121120
a.dtype(),
122-
std::vector<size_t>(a.shape(), a.shape() + a.ndim()),
123-
std::vector<size_t>(a.strides(), a.strides() + a.ndim()),
121+
a.shape(), a.shape() + a.ndim(),
122+
a.strides(), a.strides() + a.ndim(),
124123
a.data(),
125124
a
126125
);

0 commit comments

Comments
 (0)