Skip to content

Commit 891be80

Browse files
Mmanu ChaturvediEricCousineau-TRI
Mmanu Chaturvedi
authored andcommitted
Add ability to create object matrices
1 parent a303c6f commit 891be80

File tree

4 files changed

+251
-22
lines changed

4 files changed

+251
-22
lines changed

include/pybind11/eigen.h

Lines changed: 144 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>>
112112
template <typename PlainObjectType, int Options, typename StrideType>
113113
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };
114114

115+
template <typename Scalar> bool is_pyobject_() {
116+
return static_cast<pybind11::detail::npy_api::constants>(npy_format_descriptor<Scalar>::value) == npy_api::NPY_OBJECT_;
117+
}
118+
115119
// Helper struct for extracting information from an Eigen type
116120
template <typename Type_> struct EigenProps {
117121
using Type = Type_;
@@ -144,14 +148,19 @@ template <typename Type_> struct EigenProps {
144148
const auto dims = a.ndim();
145149
if (dims < 1 || dims > 2)
146150
return false;
147-
151+
bool is_pyobject = false;
152+
if (is_pyobject_<Scalar>())
153+
is_pyobject = true;
154+
ssize_t scalar_size = (is_pyobject ? static_cast<ssize_t>(sizeof(PyObject*)) :
155+
static_cast<ssize_t>(sizeof(Scalar)));
148156
if (dims == 2) { // Matrix type: require exact match (or dynamic)
149157

150158
EigenIndex
151159
np_rows = a.shape(0),
152160
np_cols = a.shape(1),
153-
np_rstride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar)),
154-
np_cstride = a.strides(1) / static_cast<ssize_t>(sizeof(Scalar));
161+
np_rstride = a.strides(0) / scalar_size,
162+
np_cstride = a.strides(1) / scalar_size;
163+
155164
if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
156165
return false;
157166

@@ -161,7 +170,7 @@ template <typename Type_> struct EigenProps {
161170
// Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
162171
// is used, we want the (single) numpy stride value.
163172
const EigenIndex n = a.shape(0),
164-
stride = a.strides(0) / static_cast<ssize_t>(sizeof(Scalar));
173+
stride = a.strides(0) / scalar_size;
165174

166175
if (vector) { // Eigen type is a compile-time vector
167176
if (fixed && size != n)
@@ -212,11 +221,51 @@ template <typename Type_> struct EigenProps {
212221
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
213222
constexpr ssize_t elem_size = sizeof(typename props::Scalar);
214223
array a;
215-
if (props::vector)
216-
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
217-
else
218-
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
219-
src.data(), base);
224+
using Scalar = typename props::Type::Scalar;
225+
bool is_pyoject = static_cast<pybind11::detail::npy_api::constants>(npy_format_descriptor<Scalar>::value) == npy_api::NPY_OBJECT_;
226+
227+
if (!is_pyoject) {
228+
if (props::vector)
229+
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
230+
else
231+
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
232+
src.data(), base);
233+
}
234+
else {
235+
if (props::vector) {
236+
a = array(
237+
npy_format_descriptor<Scalar>::dtype(),
238+
{ (size_t) src.size() },
239+
nullptr,
240+
base
241+
);
242+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
243+
for (ssize_t i = 0; i < src.size(); ++i) {
244+
const Scalar src_val = props::fixed_rows ? src(0, i) : src(i, 0);
245+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src_val, policy, base));
246+
if (!value_)
247+
return handle();
248+
a.attr("itemset")(i, value_);
249+
}
250+
}
251+
else {
252+
a = array(
253+
npy_format_descriptor<Scalar>::dtype(),
254+
{(size_t) src.rows(), (size_t) src.cols()},
255+
nullptr,
256+
base
257+
);
258+
auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy;
259+
for (ssize_t i = 0; i < src.rows(); ++i) {
260+
for (ssize_t j = 0; j < src.cols(); ++j) {
261+
auto value_ = reinterpret_steal<object>(make_caster<Scalar>::cast(src(i, j), policy, base));
262+
if (!value_)
263+
return handle();
264+
a.attr("itemset")(i, j, value_);
265+
}
266+
}
267+
}
268+
}
220269

221270
if (!writeable)
222271
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
@@ -270,14 +319,46 @@ struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
270319
auto fits = props::conformable(buf);
271320
if (!fits)
272321
return false;
273-
322+
int result = 0;
274323
// Allocate the new type, then build a numpy reference into it
275324
value = Type(fits.rows, fits.cols);
276-
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
277-
if (dims == 1) ref = ref.squeeze();
278-
else if (ref.ndim() == 1) buf = buf.squeeze();
279-
280-
int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
325+
bool is_pyobject = is_pyobject_<Scalar>();
326+
327+
if (!is_pyobject) {
328+
auto ref = reinterpret_steal<array>(eigen_ref_array<props>(value));
329+
if (dims == 1) ref = ref.squeeze();
330+
else if (ref.ndim() == 1) buf = buf.squeeze();
331+
result =
332+
detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr());
333+
}
334+
else {
335+
if (dims == 1) {
336+
if (Type::RowsAtCompileTime == Eigen::Dynamic)
337+
value.resize(buf.shape(0), 1);
338+
if (Type::ColsAtCompileTime == Eigen::Dynamic)
339+
value.resize(1, buf.shape(0));
340+
341+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
342+
make_caster <Scalar> conv_val;
343+
if (!conv_val.load(buf.attr("item")(i).cast<pybind11::object>(), convert))
344+
return false;
345+
value(i) = cast_op<Scalar>(conv_val);
346+
}
347+
} else {
348+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
349+
value.resize(buf.shape(0), buf.shape(1));
350+
}
351+
for (ssize_t i = 0; i < buf.shape(0); ++i) {
352+
for (ssize_t j = 0; j < buf.shape(1); ++j) {
353+
// p is the const void pointer to the item
354+
make_caster<Scalar> conv_val;
355+
if (!conv_val.load(buf.attr("item")(i,j).cast<pybind11::object>(), convert))
356+
return false;
357+
value(i,j) = cast_op<Scalar>(conv_val);
358+
}
359+
}
360+
}
361+
}
281362

282363
if (result < 0) { // Copy failed!
283364
PyErr_Clear();
@@ -429,13 +510,19 @@ struct type_caster<
429510
// storage order conversion. (Note that we refuse to use this temporary copy when loading an
430511
// argument for a Ref<M> with M non-const, i.e. a read-write reference).
431512
Array copy_or_ref;
513+
typename std::remove_cv<PlainObjectType>::type val;
432514
public:
433515
bool load(handle src, bool convert) {
434516
// First check whether what we have is already an array of the right type. If not, we can't
435517
// avoid a copy (because the copy is also going to do type conversion).
436518
bool need_copy = !isinstance<Array>(src);
437519

438520
EigenConformable<props::row_major> fits;
521+
bool is_pyobject = false;
522+
if (is_pyobject_<Scalar>()) {
523+
is_pyobject = true;
524+
need_copy = true;
525+
}
439526
if (!need_copy) {
440527
// We don't need a converting copy, but we also need to check whether the strides are
441528
// compatible with the Ref's stride requirements
@@ -458,15 +545,53 @@ struct type_caster<
458545
// We need to copy: If we need a mutable reference, or we're not supposed to convert
459546
// (either because we're in the no-convert overload pass, or because we're explicitly
460547
// instructed not to copy (via `py::arg().noconvert()`) we have to fail loading.
461-
if (!convert || need_writeable) return false;
548+
if (!is_pyobject && (!convert || need_writeable)) {
549+
return false;
550+
}
462551

463552
Array copy = Array::ensure(src);
464553
if (!copy) return false;
465554
fits = props::conformable(copy);
466-
if (!fits || !fits.template stride_compatible<props>())
555+
if (!fits || !fits.template stride_compatible<props>()) {
467556
return false;
468-
copy_or_ref = std::move(copy);
469-
loader_life_support::add_patient(copy_or_ref);
557+
}
558+
559+
if (!is_pyobject) {
560+
copy_or_ref = std::move(copy);
561+
loader_life_support::add_patient(copy_or_ref);
562+
}
563+
else {
564+
auto dims = copy.ndim();
565+
if (dims == 1) {
566+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
567+
val.resize(copy.shape(0), 1);
568+
}
569+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
570+
make_caster <Scalar> conv_val;
571+
if (!conv_val.load(copy.attr("item")(i).template cast<pybind11::object>(),
572+
convert))
573+
return false;
574+
val(i) = cast_op<Scalar>(conv_val);
575+
576+
}
577+
} else {
578+
if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) {
579+
val.resize(copy.shape(0), copy.shape(1));
580+
}
581+
for (ssize_t i = 0; i < copy.shape(0); ++i) {
582+
for (ssize_t j = 0; j < copy.shape(1); ++j) {
583+
// p is the const void pointer to the item
584+
make_caster <Scalar> conv_val;
585+
if (!conv_val.load(copy.attr("item")(i, j).template cast<pybind11::object>(),
586+
convert))
587+
return false;
588+
val(i, j) = cast_op<Scalar>(conv_val);
589+
}
590+
}
591+
}
592+
ref.reset(new Type(val));
593+
return true;
594+
}
470595
}
471596

472597
ref.reset();

include/pybind11/numpy.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,21 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
12271227
::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
12281228
({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
12291229

1230+
#define PYBIND11_NUMPY_OBJECT_DTYPE(Type) \
1231+
namespace pybind11 { namespace detail { \
1232+
template <> struct npy_format_descriptor<Type> { \
1233+
public: \
1234+
enum { value = npy_api::NPY_OBJECT_ }; \
1235+
static pybind11::dtype dtype() { \
1236+
if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) { \
1237+
return reinterpret_borrow<pybind11::dtype>(ptr); \
1238+
} \
1239+
pybind11_fail("Unsupported buffer format!"); \
1240+
} \
1241+
static constexpr auto name = _("object"); \
1242+
}; \
1243+
}}
1244+
12301245
#endif // __CLION_IDE__
12311246

12321247
template <class T>

tests/test_eigen.cpp

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
#include <pybind11/eigen.h>
1313
#include <pybind11/stl.h>
1414
#include <Eigen/Cholesky>
15+
#include <unsupported/Eigen/AutoDiff>
16+
#include "Eigen/src/Core/util/DisableStupidWarnings.h"
1517

1618
using MatrixXdR = Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
17-
18-
19+
typedef Eigen::AutoDiffScalar<Eigen::VectorXd> ADScalar;
20+
typedef Eigen::Matrix<ADScalar, Eigen::Dynamic, 1> VectorXADScalar;
21+
typedef Eigen::Matrix<ADScalar, 1, Eigen::Dynamic> VectorXADScalarR;
22+
PYBIND11_NUMPY_OBJECT_DTYPE(ADScalar);
1923

2024
// Sets/resets a testing reference matrix to have values of 10*r + c, where r and c are the
2125
// (1-based) row/column number.
@@ -74,7 +78,9 @@ TEST_SUBMODULE(eigen, m) {
7478
using FixedMatrixR = Eigen::Matrix<float, 5, 6, Eigen::RowMajor>;
7579
using FixedMatrixC = Eigen::Matrix<float, 5, 6>;
7680
using DenseMatrixR = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
81+
using DenseADScalarMatrixR = Eigen::Matrix<ADScalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
7782
using DenseMatrixC = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>;
83+
using DenseADScalarMatrixC = Eigen::Matrix<ADScalar, Eigen::Dynamic, Eigen::Dynamic>;
7884
using FourRowMatrixC = Eigen::Matrix<float, 4, Eigen::Dynamic>;
7985
using FourColMatrixC = Eigen::Matrix<float, Eigen::Dynamic, 4>;
8086
using FourRowMatrixR = Eigen::Matrix<float, 4, Eigen::Dynamic>;
@@ -86,10 +92,14 @@ TEST_SUBMODULE(eigen, m) {
8692

8793
// various tests
8894
m.def("double_col", [](const Eigen::VectorXf &x) -> Eigen::VectorXf { return 2.0f * x; });
95+
m.def("double_adscalar_col", [](const VectorXADScalar &x) -> VectorXADScalar { return 2.0f * x; });
8996
m.def("double_row", [](const Eigen::RowVectorXf &x) -> Eigen::RowVectorXf { return 2.0f * x; });
97+
m.def("double_adscalar_row", [](const VectorXADScalarR &x) -> VectorXADScalarR { return 2.0f * x; });
9098
m.def("double_complex", [](const Eigen::VectorXcf &x) -> Eigen::VectorXcf { return 2.0f * x; });
9199
m.def("double_threec", [](py::EigenDRef<Eigen::Vector3f> x) { x *= 2; });
100+
m.def("double_adscalarc", [](py::EigenDRef<VectorXADScalar> x) { x *= 2; });
92101
m.def("double_threer", [](py::EigenDRef<Eigen::RowVector3f> x) { x *= 2; });
102+
m.def("double_adscalarr", [](py::EigenDRef<VectorXADScalarR> x) { x *= 2; });
93103
m.def("double_mat_cm", [](Eigen::MatrixXf x) -> Eigen::MatrixXf { return 2.0f * x; });
94104
m.def("double_mat_rm", [](DenseMatrixR x) -> DenseMatrixR { return 2.0f * x; });
95105

@@ -134,6 +144,12 @@ TEST_SUBMODULE(eigen, m) {
134144
return m;
135145
}, py::return_value_policy::reference);
136146

147+
// Increments ADScalar Matrix
148+
m.def("incr_adscalar_matrix", [](Eigen::Ref<DenseADScalarMatrixC> m, double v) {
149+
m += DenseADScalarMatrixC::Constant(m.rows(), m.cols(), v);
150+
return m;
151+
}, py::return_value_policy::reference);
152+
137153
// Same, but accepts a matrix of any strides
138154
m.def("incr_matrix_any", [](py::EigenDRef<Eigen::MatrixXd> m, double v) {
139155
m += Eigen::MatrixXd::Constant(m.rows(), m.cols(), v);
@@ -168,12 +184,16 @@ TEST_SUBMODULE(eigen, m) {
168184
// return value referencing/copying tests:
169185
class ReturnTester {
170186
Eigen::MatrixXd mat = create();
187+
DenseADScalarMatrixR ad_mat = create_ADScalar_mat();
171188
public:
172189
ReturnTester() { print_created(this); }
173190
~ReturnTester() { print_destroyed(this); }
174-
static Eigen::MatrixXd create() { return Eigen::MatrixXd::Ones(10, 10); }
191+
static Eigen::MatrixXd create() { return Eigen::MatrixXd::Ones(10, 10); }
192+
static DenseADScalarMatrixR create_ADScalar_mat() { DenseADScalarMatrixR ad_mat(2, 2);
193+
ad_mat << 1, 2, 3, 7; return ad_mat; }
175194
static const Eigen::MatrixXd createConst() { return Eigen::MatrixXd::Ones(10, 10); }
176195
Eigen::MatrixXd &get() { return mat; }
196+
DenseADScalarMatrixR& get_ADScalarMat() {return ad_mat;}
177197
Eigen::MatrixXd *getPtr() { return &mat; }
178198
const Eigen::MatrixXd &view() { return mat; }
179199
const Eigen::MatrixXd *viewPtr() { return &mat; }
@@ -192,6 +212,7 @@ TEST_SUBMODULE(eigen, m) {
192212
.def_static("create", &ReturnTester::create)
193213
.def_static("create_const", &ReturnTester::createConst)
194214
.def("get", &ReturnTester::get, rvp::reference_internal)
215+
.def("get_ADScalarMat", &ReturnTester::get_ADScalarMat, rvp::reference_internal)
195216
.def("get_ptr", &ReturnTester::getPtr, rvp::reference_internal)
196217
.def("view", &ReturnTester::view, rvp::reference_internal)
197218
.def("view_ptr", &ReturnTester::view, rvp::reference_internal)
@@ -211,6 +232,18 @@ TEST_SUBMODULE(eigen, m) {
211232
.def("corners_const", &ReturnTester::cornersConst, rvp::reference_internal)
212233
;
213234

235+
py::class_<ADScalar>(m, "AutoDiffXd")
236+
.def("__init__",
237+
[](ADScalar & self,
238+
double value,
239+
const Eigen::VectorXd& derivatives) {
240+
new (&self) ADScalar(value, derivatives);
241+
})
242+
.def("value", [](const ADScalar & self) {
243+
return self.value();
244+
})
245+
;
246+
214247
// test_special_matrix_objects
215248
// Returns a DiagonalMatrix with diagonal (1,2,3,...)
216249
m.def("incr_diag", [](int k) {
@@ -295,6 +328,9 @@ TEST_SUBMODULE(eigen, m) {
295328
m.def("iss1105_col", [](Eigen::VectorXd) { return true; });
296329
m.def("iss1105_row", [](Eigen::RowVectorXd) { return true; });
297330

331+
m.def("iss1105_col_obj", [](VectorXADScalar) { return true; });
332+
m.def("iss1105_row_obj", [](VectorXADScalarR) { return true; });
333+
298334
// test_named_arguments
299335
// Make sure named arguments are working properly:
300336
m.def("matrix_multiply", [](const py::EigenDRef<const Eigen::MatrixXd> A, const py::EigenDRef<const Eigen::MatrixXd> B)

0 commit comments

Comments
 (0)