Skip to content

Commit ad75327

Browse files
committed
Accept abitrary containers and iterators for shape/strides
This adds support for constructing `buffer_info` and `array`s using arbitrary containers or iterators instead of requiring a vector. This is primarily needed by PR pybind#782 (which makes strides signed to properly support negative strides), but also needs to preserve backwards compatibility with 2.1 and earlier which accepts the strides parameter as a vector of size_t's. Rather than adding nearly duplicate constructors for each stride-taking constructor, it seems nicer to simply allow any type of container (or iterator pairs). This adds iterator pair constructors, and also adds a `detail::any_container` class that handles implicit conversion of arbitrary containers into a vector of the desired type.
1 parent 83d1a88 commit ad75327

File tree

5 files changed

+125
-60
lines changed

5 files changed

+125
-60
lines changed

include/pybind11/buffer_info.h

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
BSD-style license that can be found in the LICENSE file.
88
*/
99

10-
#pragma once
10+
#pragma once
1111

1212
#include "common.h"
1313

@@ -26,25 +26,29 @@ struct buffer_info {
2626
buffer_info() { }
2727

2828
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t ndim,
29-
const std::vector<size_t> &shape, const std::vector<size_t> &strides)
30-
: ptr(ptr), itemsize(itemsize), size(1), format(format),
31-
ndim(ndim), shape(shape), strides(strides) {
29+
detail::any_container<size_t> shape_in, detail::any_container<size_t> strides_in)
30+
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
31+
shape(std::move(shape_in)), strides(std::move(strides_in)) {
32+
if (ndim != shape.size() || ndim != strides.size())
33+
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
3234
for (size_t i = 0; i < ndim; ++i)
3335
size *= shape[i];
3436
}
3537

38+
39+
template <typename ShapeIt, typename StridesIt,
40+
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
41+
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t ndim,
42+
ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last)
43+
: buffer_info(ptr, itemsize, format, ndim, {shape_first, shape_last}, {strides_first, strides_last}) { }
44+
3645
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t size)
37-
: buffer_info(ptr, itemsize, format, 1, std::vector<size_t> { size },
38-
std::vector<size_t> { itemsize }) { }
39-
40-
explicit buffer_info(Py_buffer *view, bool ownview = true)
41-
: ptr(view->buf), itemsize((size_t) view->itemsize), size(1), format(view->format),
42-
ndim((size_t) view->ndim), shape((size_t) view->ndim), strides((size_t) view->ndim), view(view), ownview(ownview) {
43-
for (size_t i = 0; i < (size_t) view->ndim; ++i) {
44-
shape[i] = (size_t) view->shape[i];
45-
strides[i] = (size_t) view->strides[i];
46-
size *= shape[i];
47-
}
46+
: buffer_info(ptr, itemsize, format, 1, { size }, { itemsize }) { }
47+
48+
explicit buffer_info(Py_buffer *view, bool ownview_in = true)
49+
: buffer_info(view->buf, (size_t) view->itemsize, view->format, (size_t) view->ndim,
50+
view->shape, view->shape + view->ndim, view->strides, view->strides + view->ndim) {
51+
ownview = ownview_in;
4852
}
4953

5054
buffer_info(const buffer_info &) = delete;

include/pybind11/common.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,12 @@ struct is_instantiation<Class, Class<Us...>> : std::true_type { };
490490
/// Check if T is std::shared_ptr<U> where U can be anything
491491
template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;
492492

493+
/// Check if T looks like an input iterator
494+
template <typename T, typename = void> struct is_input_iterator : std::false_type {};
495+
template <typename T>
496+
struct is_input_iterator<T, void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
497+
: std::true_type {};
498+
493499
/// Ignore that a variable is unused in compiler warnings
494500
inline void ignore_unused(const int *) { }
495501

@@ -651,4 +657,46 @@ static constexpr auto const_ = std::true_type{};
651657

652658
#endif // overload_cast
653659

660+
NAMESPACE_BEGIN(detail)
661+
662+
// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from
663+
// any standard container (or C-style array) supporting std::begin/std::end.
664+
template <typename T>
665+
class any_container {
666+
std::vector<T> v;
667+
public:
668+
any_container() = default;
669+
670+
// Can construct from a pair of iterators
671+
template <typename It, typename = enable_if_t<is_input_iterator<It>::value>>
672+
any_container(It first, It last) : v(first, last) { }
673+
674+
// Implicit conversion constructor from any arbitrary container type with values convertible to T
675+
template <typename Container, typename = enable_if_t<std::is_convertible<decltype(*std::begin(std::declval<const Container &>())), T>::value>>
676+
any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { }
677+
678+
// initializer_list's aren't deducible, so don't get matched by the above template; we need this
679+
// to explicitly allow implicit conversion from one:
680+
template <typename TIn, typename = enable_if_t<std::is_convertible<TIn, T>::value>>
681+
any_container(const std::initializer_list<TIn> &c) : any_container(c.begin(), c.end()) { }
682+
683+
// Avoid copying if given an rvalue vector of the correct type.
684+
any_container(std::vector<T> &&v) : v(std::move(v)) { }
685+
686+
// Moves the vector out of an rvalue any_container
687+
operator std::vector<T> &&() && { return std::move(v); }
688+
689+
// Dereferencing obtains a reference to the underlying vector
690+
std::vector<T> &operator*() { return v; }
691+
const std::vector<T> &operator*() const { return v; }
692+
693+
// -> lets you call methods on the underlying vector
694+
std::vector<T> *operator->() { return &v; }
695+
const std::vector<T> *operator->() const { return &v; }
696+
};
697+
698+
NAMESPACE_END(detail)
699+
700+
701+
654702
NAMESPACE_END(pybind11)

include/pybind11/eigen.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -201,18 +201,13 @@ template <typename Type_> struct EigenProps {
201201
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
202202
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
203203
constexpr size_t elem_size = sizeof(typename props::Scalar);
204-
std::vector<size_t> shape, strides;
205-
if (props::vector) {
206-
shape.push_back(src.size());
207-
strides.push_back(elem_size * src.innerStride());
208-
}
209-
else {
210-
shape.push_back(src.rows());
211-
shape.push_back(src.cols());
212-
strides.push_back(elem_size * src.rowStride());
213-
strides.push_back(elem_size * src.colStride());
214-
}
215-
array a(std::move(shape), std::move(strides), src.data(), base);
204+
array a;
205+
if (props::vector)
206+
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
207+
else
208+
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
209+
src.data(), base);
210+
216211
if (!writeable)
217212
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
218213

include/pybind11/numpy.h

Lines changed: 49 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -455,12 +455,18 @@ class array : public buffer {
455455

456456
array() : array(0, static_cast<const double *>(nullptr)) {}
457457

458-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
459-
const std::vector<size_t> &strides, const void *ptr = nullptr,
460-
handle base = handle()) {
461-
auto& api = detail::npy_api::get();
462-
auto ndim = shape.size();
463-
if (shape.size() != strides.size())
458+
using ShapeContainer = detail::any_container<Py_intptr_t>;
459+
using StridesContainer = detail::any_container<Py_intptr_t>;
460+
461+
// Constructs an array taking shape/strides from arbitrary container types
462+
array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
463+
const void *ptr = nullptr, handle base = handle()) {
464+
465+
if (strides->empty())
466+
strides = default_strides(*shape, dt.itemsize());
467+
468+
auto ndim = shape->size();
469+
if (ndim != strides->size())
464470
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
465471
auto descr = dt;
466472

@@ -474,10 +480,9 @@ class array : public buffer {
474480
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
475481
}
476482

483+
auto &api = detail::npy_api::get();
477484
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
478-
api.PyArray_Type_, descr.release().ptr(), (int) ndim,
479-
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())),
480-
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())),
485+
api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape->data(), strides->data(),
481486
const_cast<void *>(ptr), flags, nullptr));
482487
if (!tmp)
483488
pybind11_fail("NumPy: unable to create array!");
@@ -491,27 +496,37 @@ class array : public buffer {
491496
m_ptr = tmp.release().ptr();
492497
}
493498

494-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
495-
const void *ptr = nullptr, handle base = handle())
496-
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
499+
template <typename ShapeIt, typename StridesIt,
500+
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
501+
array(const pybind11::dtype &dt, ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
502+
const void *ptr = nullptr, handle base = handle())
503+
: array(dt, {shape_first, shape_last}, {strides_first, strides_last}, ptr, base) { }
504+
505+
array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle())
506+
: array(dt, std::move(shape), {}, ptr, base) { }
497507

498508
array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
499509
handle base = handle())
500-
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
510+
: array(dt, ShapeContainer{{ count }}, ptr, base) { }
501511

502-
template<typename T> array(const std::vector<size_t>& shape,
503-
const std::vector<size_t>& strides,
504-
const T* ptr, handle base = handle())
505-
: array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
512+
template <typename T, typename ShapeIt, typename StridesIt,
513+
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
514+
array(ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
515+
const T *ptr = nullptr, handle base = handle())
516+
: array(pybind11::dtype::of<T>(), ShapeContainer(std::move(shape_first), std::move(shape_last)),
517+
StrideContainer(std::move(strides_first), std::move(strides_last)), ptr, base) { }
506518

507519
template <typename T>
508-
array(const std::vector<size_t> &shape, const T *ptr,
509-
handle base = handle())
510-
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
520+
array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
521+
: array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
522+
523+
template <typename T>
524+
array(ShapeContainer shape, const T *ptr, handle base = handle())
525+
: array(std::move(shape), {}, ptr, base) { }
511526

512527
template <typename T>
513528
array(size_t count, const T *ptr, handle base = handle())
514-
: array(std::vector<size_t>{ count }, ptr, base) { }
529+
: array({{ count }}, ptr, base) { }
515530

516531
explicit array(const buffer_info &info)
517532
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
@@ -673,9 +688,9 @@ class array : public buffer {
673688
throw std::domain_error("array is not writeable");
674689
}
675690

676-
static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
691+
static std::vector<Py_intptr_t> default_strides(const std::vector<Py_intptr_t>& shape, size_t itemsize) {
677692
auto ndim = shape.size();
678-
std::vector<size_t> strides(ndim);
693+
std::vector<Py_intptr_t> strides(ndim);
679694
if (ndim) {
680695
std::fill(strides.begin(), strides.end(), itemsize);
681696
for (size_t i = 0; i < ndim - 1; i++)
@@ -729,14 +744,18 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
729744

730745
explicit array_t(const buffer_info& info) : array(info) { }
731746

732-
array_t(const std::vector<size_t> &shape,
733-
const std::vector<size_t> &strides, const T *ptr = nullptr,
734-
handle base = handle())
735-
: array(shape, strides, ptr, base) { }
747+
array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle())
748+
: array(std::move(shape), std::move(strides), ptr, base) { }
749+
750+
template <typename ShapeIt, typename StridesIt,
751+
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
752+
array_t(ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
753+
const T *ptr = nullptr, handle base = handle())
754+
: array(ShapeContainer(std::move(shape_first), std::move(shape_last)),
755+
StridesContainer(std::move(strides_first), std::move(strides_last)), ptr, base) { }
736756

737-
explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
738-
handle base = handle())
739-
: array(shape, ptr, base) { }
757+
explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
758+
: array(std::move(shape), ptr, base) { }
740759

741760
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
742761
: array(count, ptr, base) { }

tests/test_numpy_array.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include <pybind11/stl.h>
1414

1515
#include <cstdint>
16-
#include <vector>
1716

1817
using arr = py::array;
1918
using arr_t = py::array_t<uint16_t, 0>;
@@ -119,8 +118,8 @@ test_initializer numpy_array([](py::module &m) {
119118
sm.def("wrap", [](py::array a) {
120119
return py::array(
121120
a.dtype(),
122-
std::vector<size_t>(a.shape(), a.shape() + a.ndim()),
123-
std::vector<size_t>(a.strides(), a.strides() + a.ndim()),
121+
a.shape(), a.shape() + a.ndim(),
122+
a.strides(), a.strides() + a.ndim(),
124123
a.data(),
125124
a
126125
);

0 commit comments

Comments
 (0)