Skip to content

Add unchecked array access via proxy object #746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 22, 2017
41 changes: 41 additions & 0 deletions docs/advanced/pycpp/numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,44 @@ simply using ``vectorize``).

The file :file:`tests/test_numpy_vectorize.cpp` contains a complete
example that demonstrates using :func:`vectorize` in more detail.

Direct access
=============

For performance reasons, particularly when dealing with very large arrays, it
is often desirable to directly access array elements without internal checking
of dimensions and bounds on every access when indices are known to be already
valid. To avoid such checks, the ``array`` class and ``array_t<T>`` template
class offer an unchecked proxy object that can be used for this unchecked
access through the ``unchecked<N>`` and ``unchecked_readonly<N>`` methods,
where ``N`` gives the required dimensionality of the array:

.. code-block:: cpp

m.def("sum_3d", [](py::array_t<double> x) {
auto r = x.unchecked_readonly<3>(); // x must have ndim = 3; can be non-writeable
double sum = 0;
for (size_t i = 0; i < r.shape(0); i++)
for (size_t j = 0; j < r.shape(1); j++)
for (size_t k = 0; k < r.shape(2); k++)
sum += r(i, j, k);
return sum;
});
m.def("increment_3d", [](py::array_t<double> x) {
auto r = x.unchecked<3>(); // Will throw if ndim != 3 or flags.writeable is false
if (x.ndim() != 3)
throw std::runtime_error("error: 3D array required");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could be missing something, but doesn't the unchecked() call already throw under the same condition?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, yeah. That's leftover from earlier versions where it didn't.

for (size_t i = 0; i < r.shape(0); i++)
for (size_t j = 0; j < r.shape(1); j++)
for (size_t k = 0; k < r.shape(2); k++)
r(i, j, k) += 1.0;
});

To obtain the proxy from an ``array`` object, you must specify both the data
type and the number of dimensions as template arguments, such as ``auto r =
myarray.unchecked<float, 2>()``.

Note that the returned proxy object directly references the array's data and
captures the array's shape and strides at the time it is created: you must take
care to ensure that the referenced array object is not destroyed or reshaped
for the lifetime of the returned object, typically by limiting the scope of the
returned instance.
126 changes: 119 additions & 7 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,12 @@ template <typename T> using is_pod_struct = all_of<
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
>;

/// Terminating overload: once every index has been consumed there is nothing left to add.
template <size_t Dim = 0, typename Strides>
size_t byte_offset_unsafe(const Strides &) {
    return 0;
}

/// Sums `index * stride` over the supplied indices, pairing the first index with `strides[Dim]`,
/// the next one with `strides[Dim + 1]`, and so on.  Neither the number of indices nor their
/// values are checked against anything.
template <size_t Dim = 0, typename Strides, typename... Ix>
size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
    // Contribution of the trailing indices (dimensions Dim + 1 onward).
    const size_t tail = byte_offset_unsafe<Dim + 1>(strides, index...);
    return tail + i * strides[Dim];
}

NAMESPACE_END(detail)

class dtype : public object {
Expand Down Expand Up @@ -328,6 +334,64 @@ class dtype : public object {
}
};

/** Proxy class providing unsafe, unchecked const access to array data.  Instances are obtained
 * through the `unchecked_readonly<T, N>()` method (or the const overload of `unchecked<T, N>()`)
 * of `array`, or through the corresponding `<N>` methods of `array_t<T>`; the constructor is not
 * public.
 *
 * The proxy keeps a raw pointer to the data together with its own copy of the first `Dims` shape
 * and stride values.  No bounds or dimensionality checking is performed on element access.
 */
template <typename T, size_t Dims>
class unchecked_const_reference {
protected:
    const unsigned char *data_;
    // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
    // make large performance gains on big, nested loops.
    std::array<size_t, Dims> shape_, strides_;

    friend class array;
    /// Stores `data` and copies the first `Dims` entries of `shape` and `strides`.
    unchecked_const_reference(const void *data, const size_t *shape, const size_t *strides)
        : data_{static_cast<const unsigned char *>(data)} {
        for (size_t i = 0; i < Dims; i++) {
            shape_[i] = shape[i];
            strides_[i] = strides[i];
        }
    }

public:
    /** Unchecked const reference access to data at the given indices.  Omitting trailing indices
     * is equivalent to specifying them as 0.
     */
    template <typename... Ix> const T &operator()(Ix... index) const {
        static_assert(sizeof...(Ix) <= Dims, "Invalid number of indices for unchecked array reference");
        // An explicit cast is used instead of `size_t{index}`: brace-initialization treats any
        // non-constant index that is not already a size_t (e.g. a plain `int`) as a narrowing
        // conversion, which is ill-formed.
        return *reinterpret_cast<const T *>(
            data_ + detail::byte_offset_unsafe(strides_, static_cast<size_t>(index)...));
    }
    /** Unchecked const reference access to data; this operator only participates if the reference
     * is to a 1-dimensional array.  When present, this is exactly equivalent to `obj(index)`.
     */
    // The condition is routed through the defaulted parameter `D` so that it depends on a
    // template parameter of this member template; testing `Dims` directly would make the class
    // itself fail to instantiate whenever `Dims != 1` instead of merely removing the operator.
    template <size_t D = Dims, typename = detail::enable_if_t<D == 1>>
    const T &operator[](size_t index) const { return operator()(index); }

    /// Returns the shape (i.e. size) of dimension `dim`; `dim` is not range-checked.
    size_t shape(size_t dim) const { return shape_[dim]; }

    /// Returns the number of dimensions of the array
    constexpr static size_t ndim() { return Dims; }
};

/** Mutable counterpart of `unchecked_const_reference`: adds non-const element access on top of
 * the const accessors of the base class.  As with the base class, no bounds or dimensionality
 * checking is performed and the constructor is not public.
 */
template <typename T, size_t Dims>
class unchecked_reference : public unchecked_const_reference<T, Dims> {
    friend class array;
    using unchecked_const_reference<T, Dims>::unchecked_const_reference;
public:
    // The non-const operators declared below would otherwise hide the const-qualified versions
    // of the base class, leaving a `const unchecked_reference` without any element access.
    using unchecked_const_reference<T, Dims>::operator();
    using unchecked_const_reference<T, Dims>::operator[];

    /// Mutable, unchecked access to data at the given indices; exactly `Dims` indices are required.
    template <typename... Ix> T &operator()(Ix... index) {
        static_assert(sizeof...(Ix) == Dims, "Invalid number of indices for unchecked array reference");
        // Reuse the base-class offset computation and strip the const it adds.
        return const_cast<T &>(unchecked_const_reference<T, Dims>::operator()(index...));
    }
    /** Mutable, unchecked access to data at the given index; this operator only participates if
     * the reference is to a 1-dimensional array.  When present, this is exactly equivalent to
     * `obj(index)`.
     */
    // `D` makes the enable_if condition depend on this member template's own parameter; using
    // `Dims` directly would be a hard error for every instantiation with `Dims != 1`.
    template <size_t D = Dims, typename = detail::enable_if_t<D == 1>>
    T &operator[](size_t index) { return operator()(index); }
};

class array : public buffer {
public:
PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
Expand Down Expand Up @@ -500,6 +564,36 @@ class array : public buffer {
return offset_at(index...) / itemsize();
}

/** Creates a mutable proxy through which the array's elements can be accessed with no bounds or
 * dimensionality checks.  Throws `std::domain_error` when the array does not have exactly `Dims`
 * dimensions, and will throw if the array is missing the `writeable` flag.  Use with care: the
 * array must not be destroyed or reshaped while the proxy is in use, and the caller is
 * responsible for never accessing an invalid dimension or an out-of-range index.
 */
template <typename T, size_t Dims>
unchecked_reference<T, Dims> unchecked() {
    const auto actual_dims = ndim();
    if (actual_dims == Dims)
        return unchecked_reference<T, Dims>(mutable_data(), shape(), strides());

    throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(actual_dims) +
                            "; expected " + std::to_string(Dims));
}

/** Creates a read-only proxy through which the array's elements can be accessed with no bounds
 * or dimensionality checks.  Unlike `unchecked()`, this does not require the underlying array to
 * have the `writeable` flag.  Throws `std::domain_error` when the array does not have exactly
 * `Dims` dimensions.  Use with care: the array must not be destroyed or reshaped while the proxy
 * is in use, and the caller is responsible for never accessing an invalid dimension or an
 * out-of-range index.
 */
template <typename T, size_t Dims>
unchecked_const_reference<T, Dims> unchecked_readonly() const {
    const auto actual_dims = ndim();
    if (actual_dims == Dims)
        return unchecked_const_reference<T, Dims>(data(), shape(), strides());

    throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(actual_dims) +
                            "; expected " + std::to_string(Dims));
}

/// Const overload: on a `const array` object, `unchecked<T, Dims>()` simply forwards to
/// `unchecked_readonly<T, Dims>()` and therefore yields the read-only proxy.
template <typename T, size_t Dims>
unchecked_const_reference<T, Dims> unchecked() const {
    return this->unchecked_readonly<T, Dims>();
}

/// Return a new view with all of the dimensions of length 1 removed
array squeeze() {
auto& api = detail::npy_api::get();
Expand All @@ -525,15 +619,9 @@ class array : public buffer {

/** Returns the byte offset of the element addressed by `index...`.  The indices are first passed
 * to `check_dimensions()`; the offset is then computed from `strides()` by
 * `detail::byte_offset_unsafe`.
 */
template<typename... Ix> size_t byte_offset(Ix... index) const {
    check_dimensions(index...);
    // A single return: the earlier call to the member `byte_offset_unsafe(index...)` made this
    // statement unreachable.  The explicit cast replaces `size_t{index}`, whose brace
    // initialization is a narrowing conversion (ill-formed) for non-constant `int` indices.
    return detail::byte_offset_unsafe(strides(), static_cast<size_t>(index)...);
}

/// Unchecked byte offset: combines `i * strides()[dim]` with the offset contributed by the
/// remaining indices, which are matched against dimensions `dim + 1` onward.
template<size_t dim = 0, typename... Ix>
size_t byte_offset_unsafe(size_t i, Ix... index) const {
    const size_t rest = byte_offset_unsafe<dim + 1>(index...);
    return rest + i * strides()[dim];
}

/// Terminating overload: with no indices left, nothing more is added to the offset.
template<size_t dim = 0>
size_t byte_offset_unsafe() const {
    return 0;
}

void check_writeable() const {
if (!writeable())
throw std::domain_error("array is not writeable");
Expand Down Expand Up @@ -637,6 +725,30 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
}

/** Typed convenience wrapper: forwards to `array::unchecked<T, Dims>()` using this array's
 * element type `T`, so only the number of dimensions has to be given.  The result is a mutable
 * proxy that performs no bounds or dimensionality checks; it will throw if the array is missing
 * the `writeable` flag.  Use with care: the array must not be destroyed or reshaped while the
 * proxy is in use, and the caller must never access an invalid dimension or out-of-range index.
 */
template <size_t Dims>
unchecked_reference<T, Dims> unchecked() {
    return array::unchecked<T, Dims>();
}

/** Typed convenience wrapper: forwards to `array::unchecked_readonly<T, Dims>()` using this
 * array's element type `T`.  The result is a read-only proxy that performs no bounds or
 * dimensionality checks; unlike `unchecked()`, the underlying array does not need the
 * `writeable` flag.  Use with care: the array must not be destroyed or reshaped while the proxy
 * is in use, and the caller must never access an invalid dimension or out-of-range index.
 */
template <size_t Dims>
unchecked_const_reference<T, Dims> unchecked_readonly() const {
    return array::unchecked_readonly<T, Dims>();
}

/// Const overload: on a `const array_t<T>` object, `unchecked<Dims>()` forwards to
/// `unchecked_readonly<Dims>()` and therefore yields the read-only proxy.
template <size_t Dims>
unchecked_const_reference<T, Dims> unchecked() const {
    return this->template unchecked_readonly<Dims>();
}

/// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
/// it). In case of an error, nullptr is returned and the Python error is cleared.
static array_t ensure(handle h) {
Expand Down