Skip to content

Commit 04d9f84

Browse files
authored
[smart_holder] Fix handling of const unique_ptr<T, D> & (do not disown). (#5332)
* Replace `unique_ptr_cref_roundtrip()` with `pass_unique_ptr_cref()`, `rtrn_unique_ptr_cref()` to make the current behavior obvious. * add in unique_ptr_storage, unique_ptr_storage_deleter * Add shared_ptr_storage (with that disowning fails as expected). * Add load_as_const_unique_ptr() * Restore original struct_smart_holder.h * factor out `smart_holder::extract_deleter()` * Better error message. * Misc cleanup/tidying. * Use `re.match("ctor_arg(_MvCtor)*_MvCtor", ...)` for compatibility with MSVC, NVHPC, ICC * Add small comments. * Fix small, inconsequential oversight in test code. * Apply suggestion by @iwanders under PR #5334 * Remove `std::move()` in `smart_holder::extract_deleter()` * Add `static_assert()` following a suggestion by @iwanders under PR #5334
1 parent 0e49463 commit 04d9f84

7 files changed

+94
-17
lines changed

include/pybind11/cast.h

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,8 +1067,14 @@ struct move_only_holder_caster<
10671067
+ clean_type_id(typeinfo->cpptype->name()) + ")");
10681068
}
10691069

1070-
template <typename>
1071-
using cast_op_type = std::unique_ptr<type, deleter>;
1070+
template <typename T_>
1071+
using cast_op_type
1072+
= conditional_t<std::is_same<typename std::remove_volatile<T_>::type,
1073+
const std::unique_ptr<type, deleter> &>::value
1074+
|| std::is_same<typename std::remove_volatile<T_>::type,
1075+
const std::unique_ptr<const type, deleter> &>::value,
1076+
const std::unique_ptr<type, deleter> &,
1077+
std::unique_ptr<type, deleter>>;
10721078

10731079
explicit operator std::unique_ptr<type, deleter>() {
10741080
if (typeinfo->holder_enum_v == detail::holder_enum_t::smart_holder) {
@@ -1077,6 +1083,28 @@ struct move_only_holder_caster<
10771083
pybind11_fail("Expected to be UNREACHABLE: " __FILE__ ":" PYBIND11_TOSTRING(__LINE__));
10781084
}
10791085

1086+
explicit operator const std::unique_ptr<type, deleter> &() {
1087+
if (typeinfo->holder_enum_v == detail::holder_enum_t::smart_holder) {
1088+
// Get shared_ptr to ensure that the Python object is not disowned elsewhere.
1089+
shared_ptr_storage = sh_load_helper.load_as_shared_ptr(value);
1090+
// Build a temporary unique_ptr that is meant to never expire.
1091+
unique_ptr_storage = std::shared_ptr<std::unique_ptr<type, deleter>>(
1092+
new std::unique_ptr<type, deleter>{
1093+
sh_load_helper.template load_as_const_unique_ptr<deleter>(
1094+
shared_ptr_storage.get())},
1095+
[](std::unique_ptr<type, deleter> *ptr) {
1096+
if (!ptr) {
1097+
pybind11_fail("FATAL: `const std::unique_ptr<T, D> &` was disowned "
1098+
"(EXPECT UNDEFINED BEHAVIOR).");
1099+
}
1100+
(void) ptr->release();
1101+
delete ptr;
1102+
});
1103+
return *unique_ptr_storage;
1104+
}
1105+
pybind11_fail("Expected to be UNREACHABLE: " __FILE__ ":" PYBIND11_TOSTRING(__LINE__));
1106+
}
1107+
10801108
bool try_implicit_casts(handle src, bool convert) {
10811109
for (auto &cast : typeinfo->implicit_casts) {
10821110
move_only_holder_caster sub_caster(*cast.first);
@@ -1097,6 +1125,8 @@ struct move_only_holder_caster<
10971125
static bool try_direct_conversions(handle) { return false; }
10981126

10991127
smart_holder_type_caster_support::load_helper<remove_cv_t<type>> sh_load_helper; // Const2Mutbl
1128+
std::shared_ptr<type> shared_ptr_storage; // Serves as a pseudo lock.
1129+
std::shared_ptr<std::unique_ptr<type, deleter>> unique_ptr_storage;
11001130
};
11011131

11021132
#endif // PYBIND11_HAS_INTERNALS_WITH_SMART_HOLDER_SUPPORT

include/pybind11/detail/struct_smart_holder.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,15 +234,17 @@ struct smart_holder {
234234
// Caller is responsible for precondition: ensure_compatible_rtti_uqp_del<T, D>() must succeed.
235235
template <typename T, typename D>
236236
std::unique_ptr<D> extract_deleter(const char *context) const {
237-
auto *gd = std::get_deleter<guarded_delete>(vptr);
237+
const auto *gd = std::get_deleter<guarded_delete>(vptr);
238238
if (gd && gd->use_del_fun) {
239239
const auto &custom_deleter_ptr = gd->del_fun.template target<custom_deleter<T, D>>();
240240
if (custom_deleter_ptr == nullptr) {
241241
throw std::runtime_error(
242242
std::string("smart_holder::extract_deleter() precondition failure (") + context
243243
+ ").");
244244
}
245-
return std::unique_ptr<D>(new D(std::move(custom_deleter_ptr->deleter)));
245+
static_assert(std::is_copy_constructible<D>::value,
246+
"Required for compatibility with smart_holder functionality.");
247+
return std::unique_ptr<D>(new D(custom_deleter_ptr->deleter));
246248
}
247249
return nullptr;
248250
}

include/pybind11/detail/type_caster_base.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,19 @@ struct load_helper : value_and_holder_helper {
814814

815815
return result;
816816
}
817+
818+
// This assumes load_as_shared_ptr succeeded(), and the returned shared_ptr is still alive.
819+
// The returned unique_ptr is meant to never expire (the behavior is undefined otherwise).
820+
template <typename D>
821+
std::unique_ptr<T, D>
822+
load_as_const_unique_ptr(T *raw_type_ptr, const char *context = "load_as_const_unique_ptr") {
823+
if (!have_holder()) {
824+
return unique_with_deleter<T, D>(nullptr, std::unique_ptr<D>());
825+
}
826+
holder().template ensure_compatible_rtti_uqp_del<T, D>(context);
827+
return unique_with_deleter<T, D>(
828+
raw_type_ptr, std::move(holder().template extract_deleter<T, D>(context)));
829+
}
817830
};
818831

819832
PYBIND11_NAMESPACE_END(smart_holder_type_caster_support)

tests/test_class_sh_basic.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,17 @@ std::string get_mtxt(atyp const &obj) { return obj.mtxt; }
120120
std::ptrdiff_t get_ptr(atyp const &obj) { return reinterpret_cast<std::ptrdiff_t>(&obj); }
121121

122122
std::unique_ptr<atyp> unique_ptr_roundtrip(std::unique_ptr<atyp> obj) { return obj; }
123+
124+
std::string pass_unique_ptr_cref(const std::unique_ptr<atyp> &obj) { return obj->mtxt; }
125+
126+
const std::unique_ptr<atyp> &rtrn_unique_ptr_cref(const std::string &mtxt) {
127+
static std::unique_ptr<atyp> obj{new atyp{"static_ctor_arg"}};
128+
if (!mtxt.empty()) {
129+
obj->mtxt = mtxt;
130+
}
131+
return obj;
132+
}
133+
123134
const std::unique_ptr<atyp> &unique_ptr_cref_roundtrip(const std::unique_ptr<atyp> &obj) {
124135
return obj;
125136
}
@@ -217,6 +228,9 @@ TEST_SUBMODULE(class_sh_basic, m) {
217228
m.def("get_ptr", get_ptr); // pass_cref
218229

219230
m.def("unique_ptr_roundtrip", unique_ptr_roundtrip); // pass_uqmp, rtrn_uqmp
231+
232+
m.def("pass_unique_ptr_cref", pass_unique_ptr_cref);
233+
m.def("rtrn_unique_ptr_cref", rtrn_unique_ptr_cref);
220234
m.def("unique_ptr_cref_roundtrip", unique_ptr_cref_roundtrip);
221235

222236
py::classh<SharedPtrStash>(m, "SharedPtrStash")

tests/test_class_sh_basic.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -151,19 +151,31 @@ def test_unique_ptr_roundtrip(num_round_trips=1000):
151151
id_orig = id_rtrn
152152

153153

154-
# This currently fails, because a unique_ptr is always loaded by value
155-
# due to pybind11/detail/smart_holder_type_casters.h:689
156-
# I think, we need to provide more cast operators.
157-
@pytest.mark.skip()
158-
def test_unique_ptr_cref_roundtrip():
154+
def test_pass_unique_ptr_cref():
155+
obj = m.atyp("ctor_arg")
156+
assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.get_mtxt(obj))
157+
assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.pass_unique_ptr_cref(obj))
158+
assert re.match("ctor_arg(_MvCtor)*_MvCtor", m.get_mtxt(obj))
159+
160+
161+
def test_rtrn_unique_ptr_cref():
162+
obj0 = m.rtrn_unique_ptr_cref("")
163+
assert m.get_mtxt(obj0) == "static_ctor_arg"
164+
obj1 = m.rtrn_unique_ptr_cref("passed_mtxt_1")
165+
assert m.get_mtxt(obj1) == "passed_mtxt_1"
166+
assert m.get_mtxt(obj0) == "passed_mtxt_1"
167+
assert obj0 is obj1
168+
169+
170+
def test_unique_ptr_cref_roundtrip(num_round_trips=1000):
171+
# Multiple roundtrips to stress-test implementation.
159172
orig = m.atyp("passenger")
160-
id_orig = id(orig)
161173
mtxt_orig = m.get_mtxt(orig)
162-
163-
recycled = m.unique_ptr_cref_roundtrip(orig)
164-
assert m.get_mtxt(orig) == mtxt_orig
165-
assert m.get_mtxt(recycled) == mtxt_orig
166-
assert id(recycled) == id_orig
174+
recycled = orig
175+
for _ in range(num_round_trips):
176+
recycled = m.unique_ptr_cref_roundtrip(recycled)
177+
assert recycled is orig
178+
assert m.get_mtxt(recycled) == mtxt_orig
167179

168180

169181
@pytest.mark.parametrize(

tests/test_class_sh_trampoline_shared_from_this.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ long pass_shared_ptr(const std::shared_ptr<Sft> &obj) {
8787
return sft.use_count();
8888
}
8989

90-
void pass_unique_ptr_cref(const std::unique_ptr<Sft> &) {
90+
std::string pass_unique_ptr_cref(const std::unique_ptr<Sft> &obj) {
91+
return obj ? obj->history : "<NULLPTR>";
92+
}
93+
void pass_unique_ptr_rref(std::unique_ptr<Sft> &&) {
9194
throw std::runtime_error("Expected to not be reached.");
9295
}
9396

@@ -138,6 +141,7 @@ TEST_SUBMODULE(class_sh_trampoline_shared_from_this, m) {
138141
m.def("use_count", use_count);
139142
m.def("pass_shared_ptr", pass_shared_ptr);
140143
m.def("pass_unique_ptr_cref", pass_unique_ptr_cref);
144+
m.def("pass_unique_ptr_rref", pass_unique_ptr_rref);
141145
m.def("make_pure_cpp_sft_raw_ptr", make_pure_cpp_sft_raw_ptr);
142146
m.def("make_pure_cpp_sft_unq_ptr", make_pure_cpp_sft_unq_ptr);
143147
m.def("make_pure_cpp_sft_shd_ptr", make_pure_cpp_sft_shd_ptr);

tests/test_class_sh_trampoline_shared_from_this.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,10 @@ def test_pass_released_shared_ptr_as_unique_ptr():
137137
obj = PySft("PySft")
138138
stash1 = m.SftSharedPtrStash(1)
139139
stash1.Add(obj) # Releases shared_ptr to C++.
140+
assert m.pass_unique_ptr_cref(obj) == "PySft_Stash1Add"
141+
assert obj.history == "PySft_Stash1Add"
140142
with pytest.raises(ValueError) as exc_info:
141-
m.pass_unique_ptr_cref(obj)
143+
m.pass_unique_ptr_rref(obj)
142144
assert str(exc_info.value) == (
143145
"Python instance is currently owned by a std::shared_ptr."
144146
)

0 commit comments

Comments
 (0)