Skip to content

Commit 4638e01

Browse files
committed
Modify py::trampoline_self_life_support semantics: if trampoline class does not inherit from this class, preserve established Inheritance Slicing behavior.
rwgk reached this point with the help of ChatGPT: * https://chatgpt.com/share/68056498-7d94-8008-8ff0-232e2aba451c The only production code change in this commit is: ``` diff --git a/include/pybind11/detail/type_caster_base.h b/include/pybind11/detail/type_caster_base.h index d4f9a41..f3d4530 100644 --- a/include/pybind11/detail/type_caster_base.h +++ b/include/pybind11/detail/type_caster_base.h @@ -776,6 +776,14 @@ struct load_helper : value_and_holder_helper { if (released_ptr) { return std::shared_ptr<T>(released_ptr, type_raw_ptr); } + auto *self_life_support + = dynamic_raw_ptr_cast_if_possible<trampoline_self_life_support>(type_raw_ptr); + if (self_life_support == nullptr) { + std::shared_ptr<void> void_shd_ptr = hld.template as_shared_ptr<void>(); + std::shared_ptr<T> to_be_released(void_shd_ptr, type_raw_ptr); + vptr_gd_ptr->released_ptr = to_be_released; + return to_be_released; + } std::shared_ptr<T> to_be_released( type_raw_ptr, shared_ptr_trampoline_self_life_support(loaded_v_h.inst)); vptr_gd_ptr->released_ptr = to_be_released; ```
1 parent 1c0b700 commit 4638e01

7 files changed

+76
-6
lines changed

include/pybind11/detail/type_caster_base.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,14 @@ struct load_helper : value_and_holder_helper {
776776
if (released_ptr) {
777777
return std::shared_ptr<T>(released_ptr, type_raw_ptr);
778778
}
779+
auto *self_life_support
780+
= dynamic_raw_ptr_cast_if_possible<trampoline_self_life_support>(type_raw_ptr);
781+
if (self_life_support == nullptr) {
782+
std::shared_ptr<void> void_shd_ptr = hld.template as_shared_ptr<void>();
783+
std::shared_ptr<T> to_be_released(void_shd_ptr, type_raw_ptr);
784+
vptr_gd_ptr->released_ptr = to_be_released;
785+
return to_be_released;
786+
}
779787
std::shared_ptr<T> to_be_released(
780788
type_raw_ptr, shared_ptr_trampoline_self_life_support(loaded_v_h.inst));
781789
vptr_gd_ptr->released_ptr = to_be_released;

include/pybind11/pybind11.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3416,14 +3416,23 @@ PYBIND11_NAMESPACE_END(detail)
34163416
template <class T>
34173417
function get_override(const T *this_ptr, const char *name) {
34183418
auto *tinfo = detail::get_type_info(typeid(T));
3419+
fflush(stderr);
3420+
printf("\nLOOOK get_override tinfo truthy = %s\n", tinfo ? "YES" : "NO");
3421+
fflush(stdout);
34193422
return tinfo ? detail::get_type_override(this_ptr, tinfo, name) : function();
34203423
}
34213424

34223425
#define PYBIND11_OVERRIDE_IMPL(ret_type, cname, name, ...) \
34233426
do { \
34243427
pybind11::gil_scoped_acquire gil; \
3428+
fflush(stderr); \
3429+
printf("\nLOOOK BEFORE static_cast<const cname *>(this)\n"); \
3430+
fflush(stdout); \
34253431
pybind11::function override \
34263432
= pybind11::get_override(static_cast<const cname *>(this), name); \
3433+
fflush(stderr); \
3434+
printf("\nLOOOK AFTER static_cast<const cname *>(this)\n"); \
3435+
fflush(stdout); \
34273436
if (override) { \
34283437
auto o = override(__VA_ARGS__); \
34293438
PYBIND11_WARNING_PUSH \

tests/test_class_sh_trampoline_shared_ptr_cpp_arg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ struct SpBase {
2828

2929
std::shared_ptr<SpBase> pass_through_shd_ptr(const std::shared_ptr<SpBase> &obj) { return obj; }
3030

31-
struct PySpBase : SpBase {
31+
struct PySpBase : SpBase, py::trampoline_self_life_support {
3232
using SpBase::SpBase;
3333
bool is_base_used() override { PYBIND11_OVERRIDE(bool, SpBase, is_base_used); }
3434
};

tests/test_class_sh_trampoline_weak_ptr.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,15 @@ struct VirtBase {
1414
virtual int get_code() { return 100; }
1515
};
1616

17-
struct PyVirtBase : VirtBase, py::trampoline_self_life_support {
17+
struct PyVirtBase : VirtBase /*, py::trampoline_self_life_support */ {
1818
using VirtBase::VirtBase;
1919
int get_code() override { PYBIND11_OVERRIDE(int, VirtBase, get_code); }
20+
21+
~PyVirtBase() override {
22+
fflush(stderr);
23+
printf("\nLOOOK ~PyVirtBase()\n");
24+
fflush(stdout);
25+
}
2026
};
2127

2228
struct WpOwner {
@@ -34,6 +40,10 @@ struct WpOwner {
3440
std::weak_ptr<VirtBase> wp;
3541
};
3642

43+
std::shared_ptr<VirtBase> pass_through_sp_VirtBase(const std::shared_ptr<VirtBase> &sp) {
44+
return sp;
45+
}
46+
3747
} // namespace class_sh_trampoline_weak_ptr
3848
} // namespace pybind11_tests
3949

@@ -48,4 +58,6 @@ TEST_SUBMODULE(class_sh_trampoline_weak_ptr, m) {
4858
.def(py::init<>())
4959
.def("set_wp", &WpOwner::set_wp)
5060
.def("get_code", &WpOwner::get_code);
61+
62+
m.def("pass_through_sp_VirtBase", pass_through_sp_VirtBase);
5163
}

tests/test_class_sh_trampoline_weak_ptr.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import gc
4+
35
import pytest
46

57
import env
@@ -20,12 +22,22 @@ def test_weak_ptr_owner(vtype, expected_code):
2022
assert obj.get_code() == expected_code
2123

2224
wpo.set_wp(obj)
23-
if vtype is m.VirtBase:
24-
assert wpo.get_code() == expected_code
25-
else:
26-
assert wpo.get_code() == -999 # THIS NEEDS FIXING (issue #5623)
25+
assert wpo.get_code() == expected_code
2726

2827
del obj
2928
if env.PYPY or env.GRAALPY:
3029
pytest.skip("Cannot reliably trigger GC")
3130
assert wpo.get_code() == -999
31+
32+
33+
@pytest.mark.parametrize(("vtype", "expected_code"), [(m.VirtBase, 100), (PyDrvd, 200)])
34+
def test_pass_through_sp_VirtBase(vtype, expected_code):
35+
obj = vtype()
36+
ptr = m.pass_through_sp_VirtBase(obj)
37+
print("\nLOOOK BEFORE del obj", flush=True)
38+
del obj
39+
print("\nLOOOK AFTER del obj", flush=True)
40+
gc.collect()
41+
print("\nLOOOK AFTER gc.collect()", flush=True)
42+
assert ptr.get_code() == expected_code
43+
print("\nLOOOK AFTER ptr.get_code()", flush=True)

tests/test_class_sp_trampoline_weak_ptr.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ struct VirtBase {
1717
struct PyVirtBase : VirtBase, py::trampoline_self_life_support {
1818
using VirtBase::VirtBase;
1919
int get_code() override { PYBIND11_OVERRIDE(int, VirtBase, get_code); }
20+
21+
~PyVirtBase() override {
22+
fflush(stderr);
23+
printf("\nLOOOK ~PyVirtBase()\n");
24+
fflush(stdout);
25+
}
2026
};
2127

2228
struct WpOwner {
@@ -48,6 +54,10 @@ struct SpOwner {
4854
std::shared_ptr<VirtBase> sp;
4955
};
5056

57+
std::shared_ptr<VirtBase> pass_through_sp_VirtBase(const std::shared_ptr<VirtBase> &sp) {
58+
return sp;
59+
}
60+
5161
} // namespace class_sp_trampoline_weak_ptr
5262
} // namespace pybind11_tests
5363

@@ -67,4 +77,6 @@ TEST_SUBMODULE(class_sp_trampoline_weak_ptr, m) {
6777
.def(py::init<>())
6878
.def("set_sp", &SpOwner::set_sp)
6979
.def("get_code", &SpOwner::get_code);
80+
81+
m.def("pass_through_sp_VirtBase", pass_through_sp_VirtBase);
7082
}

tests/test_class_sp_trampoline_weak_ptr.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import gc
4+
35
import pytest
46

57
import env
@@ -42,7 +44,9 @@ def test_with_sp_owner(vtype, expected_code):
4244
del obj
4345
if env.PYPY or env.GRAALPY:
4446
pytest.skip("Cannot reliably trigger GC")
47+
print("\nLOOOK BEFORE spo.get_code() AFTER del obj", flush=True)
4548
assert spo.get_code() == 100 # Inheritance slicing (issue #1333)
49+
print("\nLOOOK AFTER spo.get_code() AFTER del obj", flush=True)
4650

4751

4852
@pytest.mark.parametrize(("vtype", "expected_code"), [(m.VirtBase, 100), (PyDrvd, 200)])
@@ -67,3 +71,16 @@ def test_with_sp_and_wp_owners(vtype, expected_code):
6771

6872
del spo
6973
assert wpo.get_code() == -999
74+
75+
76+
@pytest.mark.parametrize(("vtype", "expected_code"), [(m.VirtBase, 100), (PyDrvd, 200)])
77+
def test_pass_through_sp_VirtBase(vtype, expected_code):
78+
obj = vtype()
79+
ptr = m.pass_through_sp_VirtBase(obj)
80+
print("\nLOOOK BEFORE del obj", flush=True)
81+
del obj
82+
print("\nLOOOK AFTER del obj", flush=True)
83+
gc.collect()
84+
print("\nLOOOK AFTER gc.collect()", flush=True)
85+
assert ptr.get_code() == expected_code
86+
print("\nLOOOK AFTER ptr.get_code()", flush=True)

0 commit comments

Comments
 (0)