Skip to content

Commit 177b631

Browse files
committed
safe/unsafe access policy
1 parent 7830e85 commit 177b631

File tree

1 file changed

+104
-66
lines changed

1 file changed

+104
-66
lines changed

include/pybind11/numpy.h

Lines changed: 104 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -316,19 +316,75 @@ class dtype : public object {
316316
}
317317
};
318318

319-
class array : public buffer {
319+
NAMESPACE_BEGIN(detail)
320+
[[noreturn]] PYBIND11_NOINLINE inline void fail_dim_check(size_t dim, size_t ndim, const std::string& msg) {
321+
throw index_error(msg + ": " + std::to_string(dim) + " (ndim = " + std::to_string(ndim) + ")");
322+
}
323+
NAMESPACE_END(detail)
324+
325+
class safe_access_policy {
326+
public:
327+
void check_axis(size_t dim, size_t ndim) const {
328+
if(dim >= ndim) {
329+
detail::fail_dim_check(dim, ndim, "invalid axis");
330+
}
331+
}
332+
333+
template <typename... Ix>
334+
void check_indices(size_t ndim, Ix...) const
335+
{
336+
if(sizeof...(Ix) > ndim) {
337+
detail::fail_dim_check(sizeof...(Ix), ndim, "too many indices for an array");
338+
}
339+
}
340+
341+
template<typename... Ix>
342+
void check_dimensions(const size_t* shape, Ix... index) const {
343+
check_dimensions_impl(size_t(0), shape, size_t(index)...);
344+
}
345+
346+
private:
347+
void check_dimensions_impl(size_t, const size_t*) const { }
348+
349+
template<typename... Ix>
350+
void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
351+
if (i >= *shape) {
352+
throw index_error(std::string("index ") + std::to_string(i) +
353+
" is out of bounds for axis " + std::to_string(axis) +
354+
" with size " + std::to_string(*shape));
355+
}
356+
check_dimensions_impl(axis + 1, shape + 1, index...);
357+
}
358+
};
359+
360+
class unsafe_access_policy {
361+
public:
362+
void check_axis(size_t, size_t) const {
363+
}
364+
365+
template <typename... Ix>
366+
void check_indices(size_t, Ix...) const {
367+
}
368+
369+
template <typename... Ix>
370+
void check_dimensions(const size_t*, Ix...) const {
371+
}
372+
};
373+
374+
template <class access_policy = safe_access_policy>
375+
class array_base : public buffer, private access_policy {
320376
public:
321-
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
377+
PYBIND11_OBJECT_CVT(array_base, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
322378

323379
enum {
324380
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
325381
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
326382
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
327383
};
328384

329-
array() : array(0, static_cast<const double *>(nullptr)) {}
385+
array_base() : array_base(0, static_cast<const double *>(nullptr)) {}
330386

331-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
387+
array_base(const pybind11::dtype &dt, const std::vector<size_t> &shape,
332388
const std::vector<size_t> &strides, const void *ptr = nullptr,
333389
handle base = handle()) {
334390
auto& api = detail::npy_api::get();
@@ -362,30 +418,30 @@ class array : public buffer {
362418
m_ptr = tmp.release().ptr();
363419
}
364420

365-
array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
421+
array_base(const pybind11::dtype &dt, const std::vector<size_t> &shape,
366422
const void *ptr = nullptr, handle base = handle())
367-
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
423+
: array_base(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
368424

369-
array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
425+
array_base(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
370426
handle base = handle())
371-
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
427+
: array_base(dt, std::vector<size_t>{ count }, ptr, base) { }
372428

373-
template<typename T> array(const std::vector<size_t>& shape,
429+
template<typename T> array_base(const std::vector<size_t>& shape,
374430
const std::vector<size_t>& strides,
375431
const T* ptr, handle base = handle())
376-
: array(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
432+
: array_base(pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
377433

378434
template <typename T>
379-
array(const std::vector<size_t> &shape, const T *ptr,
435+
array_base(const std::vector<size_t> &shape, const T *ptr,
380436
handle base = handle())
381-
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
437+
: array_base(shape, default_strides(shape, sizeof(T)), ptr, base) { }
382438

383439
template <typename T>
384-
array(size_t count, const T *ptr, handle base = handle())
385-
: array(std::vector<size_t>{ count }, ptr, base) { }
440+
array_base(size_t count, const T *ptr, handle base = handle())
441+
: array_base(std::vector<size_t>{ count }, ptr, base) { }
386442

387-
explicit array(const buffer_info &info)
388-
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
443+
explicit array_base(const buffer_info &info)
444+
: array_base(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
389445

390446
/// Array descriptor (dtype)
391447
pybind11::dtype dtype() const {
@@ -424,8 +480,7 @@ class array : public buffer {
424480

425481
/// Dimension along a given axis
426482
size_t shape(size_t dim) const {
427-
if (dim >= ndim())
428-
fail_dim_check(dim, "invalid axis");
483+
access_policy::check_axis(dim, ndim());
429484
return shape()[dim];
430485
}
431486

@@ -436,8 +491,7 @@ class array : public buffer {
436491

437492
/// Stride along a given axis
438493
size_t strides(size_t dim) const {
439-
if (dim >= ndim())
440-
fail_dim_check(dim, "invalid axis");
494+
access_policy::check_axis(dim, ndim());
441495
return strides()[dim];
442496
}
443497

@@ -473,8 +527,7 @@ class array : public buffer {
473527
/// Byte offset from beginning of the array to a given index (full or partial).
474528
/// May throw if the index would lead to out of bounds access.
475529
template<typename... Ix> size_t offset_at(Ix... index) const {
476-
if (sizeof...(index) > ndim())
477-
fail_dim_check(sizeof...(index), "too many indices for an array");
530+
access_policy::check_indices(ndim(), index...);
478531
return byte_offset(size_t(index)...);
479532
}
480533

@@ -487,15 +540,15 @@ class array : public buffer {
487540
}
488541

489542
/// Return a new view with all of the dimensions of length 1 removed
490-
array squeeze() {
543+
array_base squeeze() {
491544
auto& api = detail::npy_api::get();
492-
return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
545+
return reinterpret_steal<array_base>(api.PyArray_Squeeze_(m_ptr));
493546
}
494547

495548
/// Ensure that the argument is a NumPy array
496549
/// In case of an error, nullptr is returned and the Python error is cleared.
497-
static array ensure(handle h, int ExtraFlags = 0) {
498-
auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
550+
static array_base ensure(handle h, int ExtraFlags = 0) {
551+
auto result = reinterpret_steal<array_base>(raw_array(h.ptr(), ExtraFlags));
499552
if (!result)
500553
PyErr_Clear();
501554
return result;
@@ -504,13 +557,8 @@ class array : public buffer {
504557
protected:
505558
template<typename, typename> friend struct detail::npy_format_descriptor;
506559

507-
void fail_dim_check(size_t dim, const std::string& msg) const {
508-
throw index_error(msg + ": " + std::to_string(dim) +
509-
" (ndim = " + std::to_string(ndim()) + ")");
510-
}
511-
512560
template<typename... Ix> size_t byte_offset(Ix... index) const {
513-
check_dimensions(index...);
561+
access_policy::check_dimensions(shape(), index...);
514562
return byte_offset_unsafe(index...);
515563
}
516564

@@ -537,21 +585,6 @@ class array : public buffer {
537585
return strides;
538586
}
539587

540-
template<typename... Ix> void check_dimensions(Ix... index) const {
541-
check_dimensions_impl(size_t(0), shape(), size_t(index)...);
542-
}
543-
544-
void check_dimensions_impl(size_t, const size_t*) const { }
545-
546-
template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
547-
if (i >= *shape) {
548-
throw index_error(std::string("index ") + std::to_string(i) +
549-
" is out of bounds for axis " + std::to_string(axis) +
550-
" with size " + std::to_string(*shape));
551-
}
552-
check_dimensions_impl(axis + 1, shape + 1, index...);
553-
}
554-
555588
/// Create array from any object -- always returns a new reference
556589
static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
557590
if (ptr == nullptr)
@@ -561,35 +594,40 @@ class array : public buffer {
561594
}
562595
};
563596

564-
template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
597+
using array = array_base<safe_access_policy>;
598+
using array_unchecked = array_base<unsafe_access_policy>;
599+
600+
template <typename T, int ExtraFlags = array_base<>::forcecast, class access_policy = safe_access_policy>
601+
class array_t : public array_base<access_policy> {
565602
public:
566-
array_t() : array(0, static_cast<const T *>(nullptr)) {}
567-
array_t(handle h, borrowed_t) : array(h, borrowed) { }
568-
array_t(handle h, stolen_t) : array(h, stolen) { }
603+
using base_type = array_base<access_policy>;
604+
array_t() : base_type(0, static_cast<const T *>(nullptr)) {}
605+
array_t(handle h, borrowed_t) : base_type(h, borrowed) { }
606+
array_t(handle h, stolen_t) : base_type(h, stolen) { }
569607

570608
PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
571-
array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
609+
array_t(handle h, bool is_borrowed) : base_type(raw_array_t(h.ptr()), stolen) {
572610
if (!m_ptr) PyErr_Clear();
573611
if (!is_borrowed) Py_XDECREF(h.ptr());
574612
}
575613

576-
array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
614+
array_t(const object &o) : base_type(raw_array_t(o.ptr()), stolen) {
577615
if (!m_ptr) throw error_already_set();
578616
}
579617

580-
explicit array_t(const buffer_info& info) : array(info) { }
618+
explicit array_t(const buffer_info& info) : base_type(info) { }
581619

582620
array_t(const std::vector<size_t> &shape,
583621
const std::vector<size_t> &strides, const T *ptr = nullptr,
584622
handle base = handle())
585-
: array(shape, strides, ptr, base) { }
623+
: base_type(shape, strides, ptr, base) { }
586624

587625
explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
588626
handle base = handle())
589-
: array(shape, ptr, base) { }
627+
: base_type(shape, ptr, base) { }
590628

591629
explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
592-
: array(count, ptr, base) { }
630+
: base_type(count, ptr, base) { }
593631

594632
constexpr size_t itemsize() const {
595633
return sizeof(T);
@@ -600,25 +638,25 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
600638
}
601639

602640
template<typename... Ix> const T* data(Ix... index) const {
603-
return static_cast<const T*>(array::data(index...));
641+
return static_cast<const T*>(base_type::data(index...));
604642
}
605643

606644
template<typename... Ix> T* mutable_data(Ix... index) {
607-
return static_cast<T*>(array::mutable_data(index...));
645+
return static_cast<T*>(base_type::mutable_data(index...));
608646
}
609647

610648
// Reference to element at a given index
611649
template<typename... Ix> const T& at(Ix... index) const {
612650
if (sizeof...(index) != ndim())
613-
fail_dim_check(sizeof...(index), "index dimension mismatch");
614-
return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
651+
detail::fail_dim_check(sizeof...(index), ndim(), "index dimension mismatch");
652+
return *(static_cast<const T*>(base_type::data()) + byte_offset(size_t(index)...) / itemsize());
615653
}
616654

617655
// Mutable reference to element at a given index
618656
template<typename... Ix> T& mutable_at(Ix... index) {
619657
if (sizeof...(index) != ndim())
620-
fail_dim_check(sizeof...(index), "index dimension mismatch");
621-
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
658+
detail::fail_dim_check(sizeof...(index), ndim(), "index dimension mismatch");
659+
return *(static_cast<T*>(base_type::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
622660
}
623661

624662
/// Ensure that the argument is a NumPy array of the correct dtype.
@@ -811,7 +849,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
811849

812850
// Sanity check: verify that NumPy properly parses our buffer format string
813851
auto& api = npy_api::get();
814-
auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
852+
auto arr = array_base<>(buffer_info(nullptr, itemsize, format_str, 1));
815853
if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
816854
pybind11_fail("NumPy: invalid buffer descriptor!");
817855

@@ -1076,11 +1114,11 @@ struct vectorize_helper {
10761114
template <typename T>
10771115
explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }
10781116

1079-
object operator()(array_t<Args, array::c_style | array::forcecast>... args) {
1117+
object operator()(array_t<Args, array_base<>::c_style | array_base<>::forcecast>... args) {
10801118
return run(args..., make_index_sequence<sizeof...(Args)>());
10811119
}
10821120

1083-
template <size_t ... Index> object run(array_t<Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
1121+
template <size_t ... Index> object run(array_t<Args, array_base<>::c_style | array_base<>::forcecast>&... args, index_sequence<Index...> index) {
10841122
/* Request buffers from all parameters */
10851123
const size_t N = sizeof...(Args);
10861124

0 commit comments

Comments
 (0)