Skip to content

Commit 627da3f

Browse files
Cris Luengodean0x7d
Cris Luengo
authored andcommitted
Making a copy when casting a numpy array with negative strides to Eigen.
`EigenConformable::stride_compatible` returns false if the strides are negative. In this case, do not use `EigenConformable::stride`, as it is {0,0}. We cannot write negative strides in this element, as Eigen will throw an assertion if we do. The `type_caster` specialization for regular, dense Eigen matrices now does a second `array_t::ensure` to copy data in case of negative strides. I'm not sure that this is the best way to implement this. I have added "TODO" tags linking these changes to Eigen bug #747, which, when fixed, will allow Eigen to accept negative strides.
1 parent d400f60 commit 627da3f

File tree

4 files changed

+93
-23
lines changed

4 files changed

+93
-23
lines changed

include/pybind11/eigen.h

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,22 @@ template <typename T> using is_eigen_other = all_of<
6868
template <bool EigenRowMajor> struct EigenConformable {
6969
bool conformable = false;
7070
EigenIndex rows = 0, cols = 0;
71-
EigenDStride stride{0, 0};
71+
EigenDStride stride{0, 0}; // Only valid if negativestridees is false!
72+
bool negativestrides = false; // If true, do not use stride!
7273

7374
EigenConformable(bool fits = false) : conformable{fits} {}
7475
// Matrix type:
7576
EigenConformable(EigenIndex r, EigenIndex c,
7677
EigenIndex rstride, EigenIndex cstride) :
77-
conformable{true}, rows{r}, cols{c},
78-
stride(EigenRowMajor ? rstride : cstride /* outer stride */,
79-
EigenRowMajor ? cstride : rstride /* inner stride */)
80-
{}
78+
conformable{true}, rows{r}, cols{c} {
79+
// TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity. http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747
80+
if (rstride < 0 || cstride < 0) {
81+
negativestrides = true;
82+
} else {
83+
stride = {EigenRowMajor ? rstride : cstride /* outer stride */,
84+
EigenRowMajor ? cstride : rstride /* inner stride */ };
85+
}
86+
}
8187
// Vector type:
8288
EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride)
8389
: EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {}
@@ -86,6 +92,7 @@ template <bool EigenRowMajor> struct EigenConformable {
8692
// To have compatible strides, we need (on both dimensions) one of fully dynamic strides,
8793
// matching strides, or a dimension size of 1 (in which case the stride value is irrelevant)
8894
return
95+
!negativestrides &&
8996
(props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() ||
9097
(EigenRowMajor ? cols : rows) == 1) &&
9198
(props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() ||
@@ -138,8 +145,8 @@ template <typename Type_> struct EigenProps {
138145
EigenIndex
139146
np_rows = a.shape(0),
140147
np_cols = a.shape(1),
141-
np_rstride = a.strides(0) / sizeof(Scalar),
142-
np_cstride = a.strides(1) / sizeof(Scalar);
148+
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
149+
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
143150
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
144151
return false;
145152

@@ -149,7 +156,7 @@ template <typename Type_> struct EigenProps {
149156
// Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
150157
// is used, we want the (single) numpy stride value.
151158
const EigenIndex n = a.shape(0),
152-
stride = a.strides(0) / sizeof(Scalar);
159+
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
153160

154161
if (vector) { // Eigen type is a compile-time vector
155162
if (fixed && size != n)
@@ -255,7 +262,23 @@ struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
255262
if (!fits)
256263
return false; // Non-comformable vector/matrix types
257264

258-
value = Eigen::Map<const Type, 0, EigenDStride>(buf.data(), fits.rows, fits.cols, fits.stride);
265+
if (fits.negativestrides) {
266+
267+
// Eigen does not support negative strides, so we need to make a copy here with normal strides.
268+
// TODO: when Eigen bug #747 is fixed, remove this if case, always execute the else part.
269+
// http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747
270+
auto buf2 = array_t<Scalar,array::forcecast || array::f_style>::ensure(src);
271+
if (!buf2)
272+
return false;
273+
// not checking sizes, we already did that
274+
fits = props::conformable(buf2);
275+
value = Eigen::Map<const Type, 0, EigenDStride>(buf2.data(), fits.rows, fits.cols, fits.stride);
276+
277+
} else {
278+
279+
value = Eigen::Map<const Type, 0, EigenDStride>(buf.data(), fits.rows, fits.cols, fits.stride);
280+
281+
}
259282

260283
return true;
261284
}

include/pybind11/numpy.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,10 @@ template <typename T> using is_pod_struct = all_of<
248248
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
249249
>;
250250

251-
template <size_t Dim = 0, typename Strides> size_t byte_offset_unsafe(const Strides &) { return 0; }
251+
template <size_t Dim = 0, typename Strides> ssize_t byte_offset_unsafe(const Strides &) { return 0; }
252252
template <size_t Dim = 0, typename Strides, typename... Ix>
253-
size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
254-
return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
253+
ssize_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
254+
return static_cast<ssize_t>(i) * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
255255
}
256256

257257
/** Proxy class providing unsafe, unchecked const access to array data. This is constructed through
@@ -615,18 +615,18 @@ class array : public buffer {
615615

616616
/// Byte offset from beginning of the array to a given index (full or partial).
617617
/// May throw if the index would lead to out of bounds access.
618-
template<typename... Ix> size_t offset_at(Ix... index) const {
618+
template<typename... Ix> ssize_t offset_at(Ix... index) const {
619619
if (sizeof...(index) > ndim())
620620
fail_dim_check(sizeof...(index), "too many indices for an array");
621621
return byte_offset(size_t(index)...);
622622
}
623623

624-
size_t offset_at() const { return 0; }
624+
ssize_t offset_at() const { return 0; }
625625

626626
/// Item count from beginning of the array to a given index (full or partial).
627627
/// May throw if the index would lead to out of bounds access.
628-
template<typename... Ix> size_t index_at(Ix... index) const {
629-
return offset_at(index...) / itemsize();
628+
template<typename... Ix> ssize_t index_at(Ix... index) const {
629+
return offset_at(index...) / static_cast<ssize_t>(itemsize());
630630
}
631631

632632
/** Returns a proxy object that provides access to the array's data without bounds or
@@ -692,7 +692,7 @@ class array : public buffer {
692692
" (ndim = " + std::to_string(ndim()) + ")");
693693
}
694694

695-
template<typename... Ix> size_t byte_offset(Ix... index) const {
695+
template<typename... Ix> ssize_t byte_offset(Ix... index) const {
696696
check_dimensions(index...);
697697
return detail::byte_offset_unsafe(strides(), size_t(index)...);
698698
}
@@ -773,8 +773,8 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
773773
return sizeof(T);
774774
}
775775

776-
template<typename... Ix> size_t index_at(Ix... index) const {
777-
return offset_at(index...) / itemsize();
776+
template<typename... Ix> ssize_t index_at(Ix... index) const {
777+
return offset_at(index...) / static_cast<ssize_t>(itemsize());
778778
}
779779

780780
template<typename... Ix> const T* data(Ix... index) const {
@@ -789,14 +789,14 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
789789
template<typename... Ix> const T& at(Ix... index) const {
790790
if (sizeof...(index) != ndim())
791791
fail_dim_check(sizeof...(index), "index dimension mismatch");
792-
return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
792+
return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / static_cast<ssize_t>(itemsize()));
793793
}
794794

795795
// Mutable reference to element at a given index
796796
template<typename... Ix> T& mutable_at(Ix... index) {
797797
if (sizeof...(index) != ndim())
798798
fail_dim_check(sizeof...(index), "index dimension mismatch");
799-
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
799+
return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / static_cast<ssize_t>(itemsize()));
800800
}
801801

802802
/** Returns a proxy object that provides access to the array's data without bounds or

include/pybind11/stl_bind.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,14 +350,14 @@ vector_buffer(Class_& cl) {
350350

351351
cl.def("__init__", [](Vector& vec, buffer buf) {
352352
auto info = buf.request();
353-
if (info.ndim != 1 || info.strides[0] <= 0 || info.strides[0] % sizeof(T))
353+
if (info.ndim != 1 || info.strides[0] <= 0 || info.strides[0] % static_cast<ssize_t>(sizeof(T)))
354354
throw type_error("Only valid 1D buffers can be copied to a vector");
355355
if (!detail::compare_buffer_info<T>::compare(info) || sizeof(T) != info.itemsize)
356356
throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor<T>::format() + ")");
357357
new (&vec) Vector();
358358
vec.reserve(info.shape[0]);
359359
T *p = static_cast<T*>(info.ptr);
360-
auto step = info.strides[0] / sizeof(T);
360+
auto step = info.strides[0] / static_cast<ssize_t>(sizeof(T));
361361
T *end = p + info.shape[0] * step;
362362
for (; p < end; p += step)
363363
vec.push_back(*p);

tests/test_eigen.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,53 @@ def test_nonunit_stride_from_python():
154154
np.testing.assert_array_equal(counting_mat, [[0., 2, 2], [6, 16, 10], [6, 14, 8]])
155155

156156

157+
def test_negative_stride_from_python(msg):
158+
from pybind11_tests import (
159+
double_row, double_col, double_complex, double_mat_cm, double_mat_rm,
160+
double_threec, double_threer)
161+
162+
# Eigen doesn't support (as of yet) negative strides. When a function takes an Eigen
163+
# matrix by copy or const reference, we can pass a numpy array that has negative strides.
164+
# Otherwise, an exception will be thrown as Eigen will not be able to map the numpy array.
165+
166+
counting_mat = np.arange(9.0, dtype=np.float32).reshape((3, 3))
167+
counting_mat = counting_mat[::-1, ::-1]
168+
second_row = counting_mat[1, :]
169+
second_col = counting_mat[:, 1]
170+
np.testing.assert_array_equal(double_row(second_row), 2.0 * second_row)
171+
np.testing.assert_array_equal(double_col(second_row), 2.0 * second_row)
172+
np.testing.assert_array_equal(double_complex(second_row), 2.0 * second_row)
173+
np.testing.assert_array_equal(double_row(second_col), 2.0 * second_col)
174+
np.testing.assert_array_equal(double_col(second_col), 2.0 * second_col)
175+
np.testing.assert_array_equal(double_complex(second_col), 2.0 * second_col)
176+
177+
counting_3d = np.arange(27.0, dtype=np.float32).reshape((3, 3, 3))
178+
counting_3d = counting_3d[::-1, ::-1, ::-1]
179+
slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]]
180+
for slice_idx, ref_mat in enumerate(slices):
181+
np.testing.assert_array_equal(double_mat_cm(ref_mat), 2.0 * ref_mat)
182+
np.testing.assert_array_equal(double_mat_rm(ref_mat), 2.0 * ref_mat)
183+
184+
# Mutator:
185+
with pytest.raises(TypeError) as excinfo:
186+
double_threer(second_row)
187+
assert msg(excinfo.value) == """
188+
double_threer(): incompatible function arguments. The following argument types are supported:
189+
1. (numpy.ndarray[float32[1, 3], flags.writeable]) -> arg0: None
190+
191+
Invoked with: array([ 5., 4., 3.], dtype=float32)
192+
"""
193+
194+
with pytest.raises(TypeError) as excinfo:
195+
double_threec(second_col)
196+
assert msg(excinfo.value) == """
197+
double_threec(): incompatible function arguments. The following argument types are supported:
198+
1. (numpy.ndarray[float32[3, 1], flags.writeable]) -> arg0: None
199+
200+
Invoked with: array([ 7., 4., 1.], dtype=float32)
201+
"""
202+
203+
157204
def test_nonunit_stride_to_python():
158205
from pybind11_tests import diagonal, diagonal_1, diagonal_n, block
159206

0 commit comments

Comments
 (0)