Skip to content

Commit 0d0ae51

Browse files
committed
override: Fix wrong caching of the overrides
There was a problem when the python type, which was stored in override cache for C++ functions, was destroyed and the record wasn't removed from the override cache. Therefor, dangling pointer was stored there. Then when the memory was reused and new type was allocated at the given address and the method with the same name (as previously stored in the cache) was actually overridden in python, it would wrongly find it in the override cache for C++ functions and therefor override from python wouldn't be called. The fix is to erase the type from the override cache when the type is destroyed.
1 parent 58c7f07 commit 0d0ae51

File tree

6 files changed

+124
-1
lines changed

6 files changed

+124
-1
lines changed

include/pybind11/pybind11.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2093,6 +2093,16 @@ inline std::pair<decltype(internals::registered_types_py)::iterator, bool> all_t
20932093
// gets destroyed:
20942094
weakref((PyObject *) type, cpp_function([type](handle wr) {
20952095
get_internals().registered_types_py.erase(type);
2096+
2097+
// Actually just `std::erase_if`, but that's only available in C++20
2098+
auto &cache = get_internals().inactive_override_cache;
2099+
for (auto it = cache.begin(), last = cache.end(); it != last; ) {
2100+
if (it->first == reinterpret_cast<PyObject *>(type))
2101+
it = cache.erase(it);
2102+
else
2103+
++it;
2104+
}
2105+
20962106
wr.dec_ref();
20972107
})).release();
20982108
}

tests/test_class_sh_inheritance.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,24 @@ struct drvd2 : base1, base2 {
4949
int id() const override { return 3 * base1::base_id + 4 * base2::base_id; }
5050
};
5151

52+
class TestDerived {
53+
54+
public:
55+
virtual int func() { return 0; }
56+
57+
TestDerived() = default;
58+
~TestDerived() = default;
59+
// Non-copyable
60+
TestDerived &operator=(TestDerived const &Right) = delete;
61+
TestDerived(TestDerived const &Copy) = delete;
62+
};
63+
64+
class PyTestDerived : public TestDerived {
65+
virtual int func() override { PYBIND11_OVERRIDE(int, TestDerived, func); }
66+
};
67+
68+
inline int test_override_cache(std::shared_ptr < TestDerived> instance) { return instance->func(); }
69+
5270
// clang-format off
5371
inline drvd2 *rtrn_mptr_drvd2() { return new drvd2; }
5472
inline base1 *rtrn_mptr_drvd2_up_cast1() { return new drvd2; }
@@ -69,6 +87,8 @@ PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_inheritance::base1)
6987
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_inheritance::base2)
7088
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_inheritance::drvd2)
7189

90+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(pybind11_tests::class_sh_inheritance::TestDerived)
91+
7292
namespace pybind11_tests {
7393
namespace class_sh_inheritance {
7494

@@ -99,6 +119,12 @@ TEST_SUBMODULE(class_sh_inheritance, m) {
99119
m.def("pass_cptr_base1", pass_cptr_base1);
100120
m.def("pass_cptr_base2", pass_cptr_base2);
101121
m.def("pass_cptr_drvd2", pass_cptr_drvd2);
122+
123+
py::classh<TestDerived, PyTestDerived>(m, "TestDerived")
124+
.def(py::init_alias<>())
125+
.def("func", &TestDerived::func);
126+
127+
m.def("test_override_cache", test_override_cache);
102128
}
103129

104130
} // namespace class_sh_inheritance

tests/test_class_sh_inheritance.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,22 @@ def __init__(self):
6161
assert i1 == 110 + 21
6262
i2 = m.pass_cptr_base2(d)
6363
assert i2 == 120 + 22
64+
65+
66+
def test_python_override():
67+
def func():
68+
class Test(m.TestDerived):
69+
def func(self):
70+
return 42
71+
72+
return Test()
73+
74+
def func2():
75+
class Test(m.TestDerived):
76+
pass
77+
78+
return Test()
79+
80+
for i in range(1500):
81+
assert m.test_override_cache(func()) == 42
82+
assert m.test_override_cache(func2()) == 0

tests/test_embed/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pybind11_enable_warnings(test_embed)
2525
target_link_libraries(test_embed PRIVATE pybind11::embed Catch2::Catch2 Threads::Threads)
2626

2727
if(NOT CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_CURRENT_BINARY_DIR)
28-
file(COPY test_interpreter.py DESTINATION "${CMAKE_CURRENT_BINARY_DIR}")
28+
file(COPY test_interpreter.py test_derived.py DESTINATION "${CMAKE_CURRENT_BINARY_DIR}")
2929
endif()
3030

3131
add_custom_target(

tests/test_embed/test_derived.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# -*- coding: utf-8 -*-
2+
import sys
3+
4+
import derived_module
5+
6+
7+
def func():
8+
class Test(derived_module.TestDerived):
9+
def func(self):
10+
return 42
11+
12+
return Test()
13+
14+
15+
def func2():
16+
class Test(derived_module.TestDerived):
17+
pass
18+
19+
return Test()

tests/test_embed/test_interpreter.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <pybind11/embed.h>
2+
#include <pybind11/smart_holder.h>
23

34
#ifdef _MSC_VER
45
// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to catch
@@ -37,6 +38,24 @@ class PyWidget final : public Widget {
3738
std::string argv0() const override { PYBIND11_OVERRIDE_PURE(std::string, Widget, argv0); }
3839
};
3940

41+
class TestDerived {
42+
43+
public:
44+
virtual int func() { return 0; }
45+
46+
TestDerived() = default;
47+
virtual ~TestDerived() = default;
48+
// Non-copyable
49+
TestDerived &operator=(TestDerived const &Right) = delete;
50+
TestDerived(TestDerived const &Copy) = delete;
51+
};
52+
53+
class PyTestDerived : public TestDerived {
54+
virtual int func() override { PYBIND11_OVERRIDE(int, TestDerived, func); }
55+
};
56+
57+
PYBIND11_SMART_HOLDER_TYPE_CASTERS(TestDerived)
58+
4059
PYBIND11_EMBEDDED_MODULE(widget_module, m) {
4160
py::class_<Widget, PyWidget>(m, "Widget")
4261
.def(py::init<std::string>())
@@ -45,6 +64,12 @@ PYBIND11_EMBEDDED_MODULE(widget_module, m) {
4564
m.def("add", [](int i, int j) { return i + j; });
4665
}
4766

67+
PYBIND11_EMBEDDED_MODULE(derived_module, m) {
68+
py::classh<TestDerived, PyTestDerived>(m, "TestDerived")
69+
.def(py::init_alias<>())
70+
.def("func", &TestDerived::func);
71+
}
72+
4873
PYBIND11_EMBEDDED_MODULE(throw_exception, ) {
4974
throw std::runtime_error("C++ Error");
5075
}
@@ -73,6 +98,30 @@ TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
7398
REQUIRE(cpp_widget.the_answer() == 42);
7499
}
75100

101+
TEST_CASE("Override cache") {
102+
auto module_ = py::module_::import("test_derived");
103+
REQUIRE(py::hasattr(module_, "func"));
104+
REQUIRE(py::hasattr(module_, "func2"));
105+
106+
auto locals = py::dict(**module_.attr("__dict__"));
107+
108+
int i = 0;
109+
for (; i < 1500; ++i) {
110+
std::shared_ptr<TestDerived> p_obj;
111+
std::shared_ptr<TestDerived> p_obj2;
112+
113+
p_obj = pybind11::cast<std::shared_ptr<TestDerived>>(locals["func"]());
114+
115+
int ret = p_obj->func();
116+
117+
REQUIRE(ret == 42);
118+
119+
p_obj2 = pybind11::cast<std::shared_ptr<TestDerived>>(locals["func2"]());
120+
121+
p_obj2->func();
122+
}
123+
}
124+
76125
TEST_CASE("Import error handling") {
77126
REQUIRE_NOTHROW(py::module_::import("widget_module"));
78127
REQUIRE_THROWS_WITH(py::module_::import("throw_exception"),

0 commit comments

Comments
 (0)