Skip to content

Commit 9a538ca

Browse files
committed
safe/unsafe access policy
1 parent 7830e85 commit 9a538ca

File tree

1 file changed

+82
-47
lines changed

1 file changed

+82
-47
lines changed

include/pybind11/numpy.h

Lines changed: 82 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,64 @@ class dtype : public object {
316316
}
317317
};
318318

319-
class array : public buffer {
319+
namespace detail {
320+
inline void fail_dim_check(size_t dim, size_t ndim, const std::string& msg) {
321+
throw index_error(msg + ": " + std::to_string(dim) +
322+
" (ndim = " + std::to_string(ndim) + ")");
323+
}
324+
}
325+
326+
class safe_access_policy {
327+
public:
328+
void check_axis(size_t dim, size_t ndim) const {
329+
if(dim >= ndim) {
330+
detail::fail_dim_check(dim, ndim, "invalid axis");
331+
}
332+
}
333+
334+
template <typename... Ix>
335+
void check_indices(size_t ndim, Ix...) const
336+
{
337+
if(sizeof...(Ix) > ndim) {
338+
detail::fail_dim_check(sizeof...(Ix), ndim, "too many indices for an array");
339+
}
340+
}
341+
342+
template<typename... Ix>
343+
void check_dimensions(const size_t* shape, Ix... index) const {
344+
check_dimensions_impl(size_t(0), shape, size_t(index)...);
345+
}
346+
347+
private:
348+
void check_dimensions_impl(size_t, const size_t*) const { }
349+
350+
template<typename... Ix>
351+
void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
352+
if (i >= *shape) {
353+
throw index_error(std::string("index ") + std::to_string(i) +
354+
" is out of bounds for axis " + std::to_string(axis) +
355+
" with size " + std::to_string(*shape));
356+
}
357+
check_dimensions_impl(axis + 1, shape + 1, index...);
358+
}
359+
};
360+
361+
class unsafe_access_policy {
362+
public:
363+
void check_axis(size_t, size_t) const {
364+
}
365+
366+
template <typename... Ix>
367+
void check_indices(size_t, Ix...) const {
368+
}
369+
370+
template <typename... Ix>
371+
void check_dimensions(const size_t*, Ix...) const {
372+
}
373+
};
374+
375+
template <class access_policy = safe_access_policy>
376+
class array : public buffer, private access_policy {
320377
public:
321378
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
322379

@@ -424,8 +481,7 @@ class array : public buffer {
424481

425482
/// Dimension along a given axis
426483
size_t shape(size_t dim) const {
427-
if (dim >= ndim())
428-
fail_dim_check(dim, "invalid axis");
484+
access_policy::check_axis(dim, ndim());
429485
return shape()[dim];
430486
}
431487

@@ -436,8 +492,7 @@ class array : public buffer {
436492

437493
/// Stride along a given axis
438494
size_t strides(size_t dim) const {
439-
if (dim >= ndim())
440-
fail_dim_check(dim, "invalid axis");
495+
access_policy::check_axis(dim, ndim());
441496
return strides()[dim];
442497
}
443498

@@ -473,8 +528,7 @@ class array : public buffer {
473528
/// Byte offset from beginning of the array to a given index (full or partial).
474529
/// May throw if the index would lead to out of bounds access.
475530
template<typename... Ix> size_t offset_at(Ix... index) const {
476-
if (sizeof...(index) > ndim())
477-
fail_dim_check(sizeof...(index), "too many indices for an array");
531+
access_policy::check_indices(ndim(), index...);
478532
return byte_offset(size_t(index)...);
479533
}
480534

@@ -504,13 +558,8 @@ class array : public buffer {
504558
protected:
505559
template<typename, typename> friend struct detail::npy_format_descriptor;
506560

507-
void fail_dim_check(size_t dim, const std::string& msg) const {
508-
throw index_error(msg + ": " + std::to_string(dim) +
509-
" (ndim = " + std::to_string(ndim()) + ")");
510-
}
511-
512561
template<typename... Ix> size_t byte_offset(Ix... index) const {
513-
check_dimensions(index...);
562+
access_policy::check_dimensions(shape(), index...);
514563
return byte_offset_unsafe(index...);
515564
}
516565

@@ -537,21 +586,6 @@ class array : public buffer {
537586
return strides;
538587
}
539588

540-
template<typename... Ix> void check_dimensions(Ix... index) const {
541-
check_dimensions_impl(size_t(0), shape(), size_t(index)...);
542-
}
543-
544-
void check_dimensions_impl(size_t, const size_t*) const { }
545-
546-
template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
547-
if (i >= *shape) {
548-
throw index_error(std::string("index ") + std::to_string(i) +
549-
" is out of bounds for axis " + std::to_string(axis) +
550-
" with size " + std::to_string(*shape));
551-
}
552-
check_dimensions_impl(axis + 1, shape + 1, index...);
553-
}
554-
555589
/// Create array from any object -- always returns a new reference
556590
static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
557591
if (ptr == nullptr)
@@ -561,35 +595,36 @@ class array : public buffer {
561595
}
562596
};
563597

564-
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
598+
template <typename T, int ExtraFlags = array<>::forcecast> class array_t : public array<> {
565599
public:
566-
array_t() : array(0, static_cast<const T *>(nullptr)) {}
567-
array_t(handle h, borrowed_t) : array(h, borrowed) { }
568-
array_t(handle h, stolen_t) : array(h, stolen) { }
600+
using base_type = array<>;
601+
array_t() : base_type(0, static_cast<const T *>(nullptr)) {}
602+
array_t(handle h, borrowed_t) : base_type(h, borrowed) { }
603+
array_t(handle h, stolen_t) : base_type(h, stolen) { }
569604

570605
PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
571-
array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
606+
array_t(handle h, bool is_borrowed) : base_type(raw_array_t(h.ptr()), stolen) {
572607
if (!m_ptr) PyErr_Clear();
573608
if (!is_borrowed) Py_XDECREF(h.ptr());
574609
}
575610

576-
array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
611+
array_t(const object &o) : base_type(raw_array_t(o.ptr()), stolen) {
577612
if (!m_ptr) throw error_already_set();
578613
}
579614

580-
explicit array_t(const buffer_info& info) : array(info) { }
615+
explicit array_t(const buffer_info& info) : base_type(info) { }
581616

582617
array_t(const std::vector<size_t> &shape,
583618
const std::vector<size_t> &strides, const T *ptr = nullptr,
584619
handle base = handle())
585-
: array(shape, strides, ptr, base) { }
620+
: base_type(shape, strides, ptr, base) { }
586621

587622
explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
588623
handle base = handle())
589-
: array(shape, ptr, base) { }
624+
: base_type(shape, ptr, base) { }
590625

591626
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
592-
: array(count, ptr, base) { }
627+
: base_type(count, ptr, base) { }
593628

594629
constexpr size_t itemsize() const {
595630
return sizeof(T);
@@ -600,25 +635,25 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
600635
}
601636

602637
template<typename... Ix> const T* data(Ix... index) const {
603-
return static_cast<const T*>(array::data(index...));
638+
return static_cast<const T*>(base_type::data(index...));
604639
}
605640

606641
template<typename... Ix> T* mutable_data(Ix... index) {
607-
return static_cast<T*>(array::mutable_data(index...));
642+
return static_cast<T*>(base_type::mutable_data(index...));
608643
}
609644

610645
// Reference to element at a given index
611646
template<typename... Ix> const T& at(Ix... index) const {
612647
if (sizeof...(index) != ndim())
613-
fail_dim_check(sizeof...(index), "index dimension mismatch");
614-
return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
648+
detail::fail_dim_check(sizeof...(index), ndim(), "index dimension mismatch");
649+
return *(static_cast<const T*>(base_type::data()) + byte_offset(size_t(index)...) / itemsize());
615650
}
616651

617652
// Mutable reference to element at a given index
618653
template<typename... Ix> T& mutable_at(Ix... index) {
619654
if (sizeof...(index) != ndim())
620-
fail_dim_check(sizeof...(index), "index dimension mismatch");
621-
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
655+
detail::fail_dim_check(sizeof...(index), ndim(), "index dimension mismatch");
656+
return *(static_cast<T*>(base_type::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
622657
}
623658

624659
/// Ensure that the argument is a NumPy array of the correct dtype.
@@ -811,7 +846,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
811846

812847
// Sanity check: verify that NumPy properly parses our buffer format string
813848
auto& api = npy_api::get();
814-
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
849+
auto arr = array<>(buffer_info(nullptr, itemsize, format_str, 1));
815850
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
816851
pybind11_fail("NumPy: invalid buffer descriptor!");
817852

@@ -1076,11 +1111,11 @@ struct vectorize_helper {
10761111
template <typename T>
10771112
explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
10781113

1079-
object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
1114+
object operator()(array_t<Args, array<>::c_style | array<>::forcecast>... args) {
10801115
return run(args..., make_index_sequence<sizeof...(Args)>());
10811116
}
10821117

1083-
template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
1118+
template <size_t ... Index> object run(array_t<Args, array<>::c_style | array<>::forcecast>&... args, index_sequence<Index...> index) {
10841119
/* Request buffers from all parameters */
10851120
const size_t N = sizeof...(Args);
10861121

0 commit comments

Comments
 (0)