Skip to content

Commit 54f8341

Browse files
authored
Change implementation of the __init__() must be called when overriding __init__ safety feature to work for any metaclass. (#30095)
* Also wrap with `py::metaclass((PyObject *) &PyType_Type)` * Transfer additional tests from PyCLIF python_multiple_inheritance_test.py * Expand tests to fully cover wrapping with alternative metaclasses. * * Factor out `ensure_base_init_functions_were_called()`. * Call from new `tp_init_intercepted()` (adopting mechanism first added in PyCLIF: google/clif@7cba87d). * Remove `pybind11_meta_call()` (which was added with pybind/pybind11#2152). * Bug fix (maybe actually two bugs?): simplify condition to `type->tp_init != tp_init_intercepted` * Removing `Py_DECREF(self)` that leads to MSAN failure (Google toolchain). ``` ==6380==WARNING: MemorySanitizer: use-of-uninitialized-value #0 0x5611589c9a58 in Py_DECREF third_party/python_runtime/v3_11/Include/object.h:537:9 ... Uninitialized value was created by a heap deallocation #0 0x5611552757b0 in free third_party/llvm/llvm-project/compiler-rt/lib/msan/msan_interceptors.cpp:218:3 #1 0x56115898e06b in _PyMem_RawFree third_party/python_runtime/v3_11/Objects/obmalloc.c:154:5 #2 0x56115898f6ad in PyObject_Free third_party/python_runtime/v3_11/Objects/obmalloc.c:769:5 #3 0x561158271bcc in PyObject_GC_Del third_party/python_runtime/v3_11/Modules/gcmodule.c:2407:5 #4 0x7f21224b070c in pybind11_object_dealloc third_party/pybind11/include/pybind11/detail/class.h:483:5 #5 0x5611589c2ed0 in subtype_dealloc third_party/python_runtime/v3_11/Objects/typeobject.c:1463:5 ... ``` * IncludeCleaner fixes (Google toolchain). * Restore `type->tp_call = pybind11_meta_call;` for PyPy only. * pytest.skip("ensure_base_init_functions_were_called() does not work with PyPy and Python `type` as metaclass") * Do not intercept our own `tp_init` function (`pybind11_object_init`). * Add `derived_tp_init_registry` weakref-based cleanup. * Replace `assert()` with `if` to resolve erroneous `lambda capture 'type' is not used` diagnostics (many CI jobs; seems to be a clang issue). * Add `derived_tp_init_registry()->count(type) == 0` condition. * Changes based on feedback from @rainwoodman * Use PYBIND11_INIT_SAFETY_CHECKS_VIA_* macros, based on suggestion from @rainwoodman
1 parent 80c9ee6 commit 54f8341

File tree

4 files changed

+234
-38
lines changed

4 files changed

+234
-38
lines changed

include/pybind11/detail/class.h

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
#include "../attr.h"
1313
#include "../options.h"
1414

15+
#include <cassert>
16+
#include <unordered_map>
17+
1518
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
1619
PYBIND11_NAMESPACE_BEGIN(detail)
1720

@@ -179,6 +182,36 @@ extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name
179182
return PyType_Type.tp_getattro(obj, name);
180183
}
181184

185+
// Ensure that the base __init__ function(s) were called.
186+
// Set TypeError and return false if not.
187+
// CALLER IS RESPONSIBLE for managing the self refcount appropriately.
188+
inline bool ensure_base_init_functions_were_called(PyObject *self) {
189+
values_and_holders vhs(self);
190+
for (const auto &vh : vhs) {
191+
if (!vh.holder_constructed() && !vhs.is_redundant_value_and_holder(vh)) {
192+
PyErr_Format(PyExc_TypeError,
193+
"%.200s.__init__() must be called when overriding __init__",
194+
get_fully_qualified_tp_name(vh.type->type).c_str());
195+
return false;
196+
}
197+
}
198+
return true;
199+
}
200+
201+
// See google/pywrapcc#30095 for background.
202+
#if !defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT) \
203+
&& !defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
204+
# if !defined(PYPY_VERSION)
205+
// With CPython the safety checks work for any metaclass.
206+
// However, with PyPy this implementation does not work at all.
207+
# define PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT
208+
# else
209+
// With this the safety checks work only for the default `py::metaclass()`.
210+
# define PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS
211+
# endif
212+
#endif
213+
214+
#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
182215
/// metaclass `__call__` function that is used to create all pybind11 objects.
183216
extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, PyObject *kwargs) {
184217

@@ -188,20 +221,14 @@ extern "C" inline PyObject *pybind11_meta_call(PyObject *type, PyObject *args, P
188221
return nullptr;
189222
}
190223

191-
// Ensure that the base __init__ function(s) were called
192-
values_and_holders vhs(self);
193-
for (const auto &vh : vhs) {
194-
if (!vh.holder_constructed() && !vhs.is_redundant_value_and_holder(vh)) {
195-
PyErr_Format(PyExc_TypeError,
196-
"%.200s.__init__() must be called when overriding __init__",
197-
get_fully_qualified_tp_name(vh.type->type).c_str());
198-
Py_DECREF(self);
199-
return nullptr;
200-
}
224+
if (!ensure_base_init_functions_were_called(self)) {
225+
Py_DECREF(self);
226+
return nullptr;
201227
}
202228

203229
return self;
204230
}
231+
#endif
205232

206233
/// Cleanup the type-info for a pybind11-registered type.
207234
extern "C" inline void pybind11_meta_dealloc(PyObject *obj) {
@@ -268,7 +295,9 @@ inline PyTypeObject *make_default_metaclass() {
268295
type->tp_base = type_incref(&PyType_Type);
269296
type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE;
270297

298+
#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
271299
type->tp_call = pybind11_meta_call;
300+
#endif
272301

273302
type->tp_setattro = pybind11_meta_setattro;
274303
type->tp_getattro = pybind11_meta_getattro;
@@ -340,6 +369,33 @@ inline bool deregister_instance(instance *self, void *valptr, const type_info *t
340369
return ret;
341370
}
342371

372+
#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT)
373+
374+
using derived_tp_init_registry_type = std::unordered_map<PyTypeObject *, initproc>;
375+
376+
inline derived_tp_init_registry_type *derived_tp_init_registry() {
377+
// Intentionally leak the unordered_map:
378+
// https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables
379+
static auto *singleton = new derived_tp_init_registry_type();
380+
return singleton;
381+
}
382+
383+
extern "C" inline int tp_init_with_safety_checks(PyObject *self, PyObject *args, PyObject *kw) {
384+
assert(PyType_Check(self) == 0);
385+
const auto derived_tp_init = derived_tp_init_registry()->find(Py_TYPE(self));
386+
if (derived_tp_init == derived_tp_init_registry()->end()) {
387+
pybind11_fail("FATAL: Internal consistency check failed at " __FILE__
388+
":" PYBIND11_TOSTRING(__LINE__));
389+
}
390+
int status = (*derived_tp_init->second)(self, args, kw);
391+
if (status == 0 && !ensure_base_init_functions_were_called(self)) {
392+
return -1; // No Py_DECREF here.
393+
}
394+
return status;
395+
}
396+
397+
#endif // PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT
398+
343399
/// Instance creation function for all pybind11 types. It allocates the internal instance layout
344400
/// for holding C++ objects and holders. Allocation is done lazily (the first time the instance is
345401
/// cast to a reference or pointer), and initialization is done by an `__init__` function.
@@ -360,11 +416,7 @@ inline PyObject *make_new_instance(PyTypeObject *type) {
360416
return self;
361417
}
362418

363-
/// Instance creation function for all pybind11 types. It only allocates space for the
364-
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
365-
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
366-
return make_new_instance(type);
367-
}
419+
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *);
368420

369421
/// An `__init__` function constructs the C++ object. Users should provide at least one
370422
/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the

include/pybind11/pybind11.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,6 +1207,33 @@ class cpp_function : public function {
12071207
}
12081208
};
12091209

1210+
PYBIND11_NAMESPACE_BEGIN(detail)
1211+
1212+
/// Instance creation function for all pybind11 types. It only allocates space for the
1213+
/// C++ object, but doesn't call the constructor -- an `__init__` function must do that.
1214+
extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) {
1215+
#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT)
1216+
if (type->tp_init != pybind11_object_init && type->tp_init != tp_init_with_safety_checks
1217+
&& derived_tp_init_registry()->count(type) == 0) {
1218+
weakref((PyObject *) type, cpp_function([type](handle wr) {
1219+
auto num_erased = derived_tp_init_registry()->erase(type);
1220+
if (num_erased != 1) {
1221+
pybind11_fail("FATAL: Internal consistency check failed at " __FILE__
1222+
":" PYBIND11_TOSTRING(__LINE__) ": num_erased="
1223+
+ std::to_string(num_erased));
1224+
}
1225+
wr.dec_ref();
1226+
}))
1227+
.release();
1228+
(*derived_tp_init_registry())[type] = type->tp_init;
1229+
type->tp_init = tp_init_with_safety_checks;
1230+
}
1231+
#endif // PYBIND11_INIT_SAFETY_CHECKS_VIA_INTERCEPTING_TP_INIT
1232+
return make_new_instance(type);
1233+
}
1234+
1235+
PYBIND11_NAMESPACE_END(detail)
1236+
12101237
/// Wrapper for Python extension modules
12111238
class module_ : public object {
12121239
public:

tests/test_python_multiple_inheritance.cpp

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ namespace test_python_multiple_inheritance {
55
// Copied from:
66
// https://github.com/google/clif/blob/5718e4d0807fd3b6a8187dde140069120b81ecef/clif/testing/python_multiple_inheritance.h
77

8+
template <int> // Using int as a trick to easily generate a series of types.
89
struct CppBase {
910
explicit CppBase(int value) : base_value(value) {}
1011
int get_base_value() const { return base_value; }
@@ -14,32 +15,45 @@ struct CppBase {
1415
int base_value;
1516
};
1617

17-
struct CppDrvd : CppBase {
18-
explicit CppDrvd(int value) : CppBase(value), drvd_value(value * 3) {}
18+
template <int SerNo>
19+
struct CppDrvd : CppBase<SerNo> {
20+
explicit CppDrvd(int value) : CppBase<SerNo>(value), drvd_value(value * 3) {}
1921
int get_drvd_value() const { return drvd_value; }
2022
void reset_drvd_value(int new_value) { drvd_value = new_value; }
2123

22-
int get_base_value_from_drvd() const { return get_base_value(); }
23-
void reset_base_value_from_drvd(int new_value) { reset_base_value(new_value); }
24+
int get_base_value_from_drvd() const { return CppBase<SerNo>::get_base_value(); }
25+
void reset_base_value_from_drvd(int new_value) { CppBase<SerNo>::reset_base_value(new_value); }
2426

2527
private:
2628
int drvd_value;
2729
};
2830

31+
template <int SerNo, typename... Extra>
32+
void wrap_classes(py::module_ &m, const char *name_base, const char *name_drvd, Extra... extra) {
33+
py::class_<CppBase<SerNo>>(m, name_base, std::forward<Extra>(extra)...)
34+
.def(py::init<int>())
35+
.def("get_base_value", &CppBase<SerNo>::get_base_value)
36+
.def("reset_base_value", &CppBase<SerNo>::reset_base_value);
37+
38+
py::class_<CppDrvd<SerNo>, CppBase<SerNo>>(m, name_drvd, std::forward<Extra>(extra)...)
39+
.def(py::init<int>())
40+
.def("get_drvd_value", &CppDrvd<SerNo>::get_drvd_value)
41+
.def("reset_drvd_value", &CppDrvd<SerNo>::reset_drvd_value)
42+
.def("get_base_value_from_drvd", &CppDrvd<SerNo>::get_base_value_from_drvd)
43+
.def("reset_base_value_from_drvd", &CppDrvd<SerNo>::reset_base_value_from_drvd);
44+
}
45+
2946
} // namespace test_python_multiple_inheritance
3047

3148
TEST_SUBMODULE(python_multiple_inheritance, m) {
3249
using namespace test_python_multiple_inheritance;
33-
34-
py::class_<CppBase>(m, "CppBase")
35-
.def(py::init<int>())
36-
.def("get_base_value", &CppBase::get_base_value)
37-
.def("reset_base_value", &CppBase::reset_base_value);
38-
39-
py::class_<CppDrvd, CppBase>(m, "CppDrvd")
40-
.def(py::init<int>())
41-
.def("get_drvd_value", &CppDrvd::get_drvd_value)
42-
.def("reset_drvd_value", &CppDrvd::reset_drvd_value)
43-
.def("get_base_value_from_drvd", &CppDrvd::get_base_value_from_drvd)
44-
.def("reset_base_value_from_drvd", &CppDrvd::reset_base_value_from_drvd);
50+
wrap_classes<0>(m, "CppBase0", "CppDrvd0");
51+
wrap_classes<1>(m, "CppBase1", "CppDrvd1", py::metaclass((PyObject *) &PyType_Type));
52+
53+
m.attr("if_defined_PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS") =
54+
#if defined(PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS)
55+
true;
56+
#else
57+
false;
58+
#endif
4559
}
Lines changed: 110 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,78 @@
11
# Adapted from:
2-
# https://github.com/google/clif/blob/5718e4d0807fd3b6a8187dde140069120b81ecef/clif/testing/python/python_multiple_inheritance_test.py
2+
# https://github.com/google/clif/blob/7d388e1de7db5beeb3d7429c18a2776d8188f44f/clif/testing/python/python_multiple_inheritance_test.py
3+
4+
import pytest
35

46
from pybind11_tests import python_multiple_inheritance as m
57

8+
#
9+
# Using default py::metaclass() (used with py::class_<> for CppBase0, CppDrvd0):
10+
#
11+
12+
13+
class PC0(m.CppBase0):
14+
pass
15+
16+
17+
class PPCC0(PC0, m.CppDrvd0):
18+
pass
19+
20+
21+
class PCExplicitInitWithSuper0(m.CppBase0):
22+
def __init__(self, value):
23+
super().__init__(value + 1)
24+
25+
26+
class PCExplicitInitMissingSuper0(m.CppBase0):
27+
def __init__(self, value):
28+
del value
29+
30+
31+
class PCExplicitInitMissingSuperB0(m.CppBase0):
32+
def __init__(self, value):
33+
del value
34+
35+
36+
#
37+
# Using py::metaclass((PyObject *) &PyType_Type) (used with py::class_<> for CppBase1, CppDrvd1):
38+
# COPY-PASTE block from above, replace 0 with 1:
39+
#
640

7-
class PC(m.CppBase):
41+
42+
class PC1(m.CppBase1):
843
pass
944

1045

11-
class PPCC(PC, m.CppDrvd):
46+
class PPCC1(PC1, m.CppDrvd1):
1247
pass
1348

1449

15-
def test_PC():
16-
d = PC(11)
50+
class PCExplicitInitWithSuper1(m.CppBase1):
51+
def __init__(self, value):
52+
super().__init__(value + 1)
53+
54+
55+
class PCExplicitInitMissingSuper1(m.CppBase1):
56+
def __init__(self, value):
57+
del value
58+
59+
60+
class PCExplicitInitMissingSuperB1(m.CppBase1):
61+
def __init__(self, value):
62+
del value
63+
64+
65+
@pytest.mark.parametrize(("pc_type"), [PC0, PC1])
66+
def test_PC(pc_type):
67+
d = pc_type(11)
1768
assert d.get_base_value() == 11
1869
d.reset_base_value(13)
1970
assert d.get_base_value() == 13
2071

2172

22-
def test_PPCC():
23-
d = PPCC(11)
73+
@pytest.mark.parametrize(("ppcc_type"), [PPCC0, PPCC1])
74+
def test_PPCC(ppcc_type):
75+
d = ppcc_type(11)
2476
assert d.get_drvd_value() == 33
2577
d.reset_drvd_value(55)
2678
assert d.get_drvd_value() == 55
@@ -33,3 +85,54 @@ def test_PPCC():
3385
d.reset_base_value_from_drvd(30)
3486
assert d.get_base_value() == 30
3587
assert d.get_base_value_from_drvd() == 30
88+
89+
90+
@pytest.mark.parametrize(
91+
("pc_type"), [PCExplicitInitWithSuper0, PCExplicitInitWithSuper1]
92+
)
93+
def testPCExplicitInitWithSuper(pc_type):
94+
d = pc_type(14)
95+
assert d.get_base_value() == 15
96+
97+
98+
@pytest.mark.parametrize(
99+
("derived_type"),
100+
[
101+
PCExplicitInitMissingSuper0,
102+
PCExplicitInitMissingSuperB0,
103+
PCExplicitInitMissingSuper1,
104+
PCExplicitInitMissingSuperB1,
105+
],
106+
)
107+
def testPCExplicitInitMissingSuper(derived_type):
108+
if (
109+
m.if_defined_PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS
110+
and derived_type
111+
in (
112+
PCExplicitInitMissingSuper1,
113+
PCExplicitInitMissingSuperB1,
114+
)
115+
):
116+
pytest.skip(
117+
"PYBIND11_INIT_SAFETY_CHECKS_VIA_DEFAULT_PYBIND11_METACLASS is defined"
118+
)
119+
with pytest.raises(TypeError) as excinfo:
120+
derived_type(0)
121+
assert str(excinfo.value).endswith(
122+
".__init__() must be called when overriding __init__"
123+
)
124+
125+
126+
def test_derived_tp_init_registry_weakref_based_cleanup():
127+
def nested_function(i):
128+
class NestedClass(m.CppBase0):
129+
def __init__(self, value):
130+
super().__init__(value + 3)
131+
132+
d1 = NestedClass(i + 7)
133+
d2 = NestedClass(i + 8)
134+
return (d1.get_base_value(), d2.get_base_value())
135+
136+
for _ in range(100):
137+
assert nested_function(0) == (10, 11)
138+
assert nested_function(3) == (13, 14)

0 commit comments

Comments
 (0)