Skip to content

Commit 8d982e5

Browse files
committed
Fix lifetime issue for BaseMatrix derived in Python
Using workaround from here: pybind/pybind11#1546
1 parent c954827 commit 8d982e5

File tree

2 files changed

+73
-1
lines changed

2 files changed

+73
-1
lines changed

linalg/python_linalg.cpp

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,57 @@ using namespace ngla;
77
#include <myadt.hpp>
88

99

10+
// Workaround to ensure same lifetime for C++ and Python objects
11+
// see https://github.com/pybind/pybind11/issues/1546
12+
namespace pybind11::detail {
13+
14+
template<>
15+
struct type_caster<std::shared_ptr<BaseMatrix>>
16+
{
17+
PYBIND11_TYPE_CASTER (std::shared_ptr<BaseMatrix>, _("BaseMatrix"));
18+
19+
using BaseCaster = copyable_holder_caster<BaseMatrix, std::shared_ptr<BaseMatrix>>;
20+
21+
bool load (pybind11::handle src, bool b)
22+
{
23+
BaseCaster bc;
24+
bool success = bc.load (src, b);
25+
if (!success)
26+
{
27+
return false;
28+
}
29+
30+
auto py_obj = py::reinterpret_borrow<py::object> (src);
31+
auto base_ptr = static_cast<std::shared_ptr<BaseMatrix>> (bc);
32+
33+
// Construct a shared_ptr to the py::object
34+
auto py_obj_ptr = std::shared_ptr<object>{
35+
new object{py_obj},
36+
[](auto py_object_ptr) {
37+
// It's possible that when the shared_ptr dies we won't have the
38+
// gil (if the last holder is in a non-Python thread), so we
39+
// make sure to acquire it in the deleter.
40+
gil_scoped_acquire gil;
41+
delete py_object_ptr;
42+
}
43+
};
44+
45+
value = std::shared_ptr<BaseMatrix> (py_obj_ptr, base_ptr.get ());
46+
return true;
47+
}
48+
49+
static handle cast (std::shared_ptr<BaseMatrix> base,
50+
return_value_policy rvp,
51+
handle h)
52+
{
53+
return BaseCaster::cast (base, rvp, h);
54+
}
55+
};
56+
57+
template <>
58+
struct is_holder_type<BaseMatrix, std::shared_ptr<BaseMatrix>> : std::true_type {};
59+
}
60+
1061
template<typename T>
1162
void ExportSparseMatrix(py::module m)
1263
{
@@ -570,7 +621,7 @@ void NGS_DLL_HEADER ExportNgla(py::module &m) {
570621
new (instance) BaseMatrixTrampoline(); }
571622
)
572623
*/
573-
.def(py::init<> ([]() { return new BaseMatrixTrampoline(); }))
624+
.def(py::init<> ())
574625
.def("__str__", [](BaseMatrix &self) { return ToString<BaseMatrix>(self); } )
575626
.def_property_readonly("height", [] ( BaseMatrix & self)
576627
{ return self.Height(); }, "Height of the matrix" )

tests/pytest/test_basematrix.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ def MultTransAdd(self, s, x, y):
1515
for i in range(len(x)):
1616
y[i] = s*sum(x[i:])
1717

18+
def CreateColVector(self):
19+
return CreateVVector(5)
20+
1821
def test_derive_basematrix():
1922
m = MyMatrix()
2023

@@ -36,3 +39,21 @@ def test_derive_basematrix():
3639

3740

3841

42+
def test_derive_basematrix_lifetime():
43+
m = MyMatrix()@MyMatrix()
44+
m1 = MyMatrix()
45+
46+
x = CreateVVector(5)
47+
y = CreateVVector(5)
48+
y1 = CreateVVector(5)
49+
50+
for i in range(len(x)):
51+
x[i]=1+i
52+
53+
y.data = m*x
54+
y1.data = (m1@m1)*x
55+
assert list(y) == list(y1)
56+
57+
58+
59+

0 commit comments

Comments
 (0)