Skip to content

Commit 58802de

Browse files
authored
perf: Add object rvalue overload for accessors. Enables reference stealing (#3970)
* Add object rvalue overload for accessors. Enables reference stealing * Fix comments * Fix more comment typos * Fix bug * reorder declarations for clarity * fix another perf bug * should be static * future proof operator overloads * Fix perfect forwarding * Add a couple of tests * Remove errant include * Improve test documentation * Add dict test * add object attr tests * Optimize STL map caster and cleanup enum * Reorder to match declarations * adjust increfs * Remove comment * revert value change * add missing move
1 parent 9f7b3f7 commit 58802de

File tree

5 files changed

+97
-12
lines changed

5 files changed

+97
-12
lines changed

include/pybind11/pybind11.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,12 +2069,12 @@ struct enum_base {
20692069
str name(name_);
20702070
if (entries.contains(name)) {
20712071
std::string type_name = (std::string) str(m_base.attr("__name__"));
2072-
throw value_error(type_name + ": element \"" + std::string(name_)
2072+
throw value_error(std::move(type_name) + ": element \"" + std::string(name_)
20732073
+ "\" already exists!");
20742074
}
20752075

20762076
entries[name] = std::make_pair(value, doc);
2077-
m_base.attr(name) = value;
2077+
m_base.attr(std::move(name)) = std::move(value);
20782078
}
20792079

20802080
PYBIND11_NOINLINE void export_values() {
@@ -2610,7 +2610,7 @@ PYBIND11_NOINLINE void print(const tuple &args, const dict &kwargs) {
26102610
}
26112611

26122612
auto write = file.attr("write");
2613-
write(line);
2613+
write(std::move(line));
26142614
write(kwargs.contains("end") ? kwargs["end"] : str("\n"));
26152615

26162616
if (kwargs.contains("flush") && kwargs["flush"].cast<bool>()) {

include/pybind11/pytypes.h

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ class object_api : public pyobject_tag {
8585
or `object` subclass causes a call to ``__setitem__``.
8686
\endrst */
8787
item_accessor operator[](handle key) const;
88-
/// See above (the only difference is that they key is provided as a string literal)
88+
/// See above (the only difference is that the key's reference is stolen)
89+
item_accessor operator[](object &&key) const;
90+
/// See above (the only difference is that the key is provided as a string literal)
8991
item_accessor operator[](const char *key) const;
9092

9193
/** \rst
@@ -95,7 +97,9 @@ class object_api : public pyobject_tag {
9597
or `object` subclass causes a call to ``setattr``.
9698
\endrst */
9799
obj_attr_accessor attr(handle key) const;
98-
/// See above (the only difference is that they key is provided as a string literal)
100+
/// See above (the only difference is that the key's reference is stolen)
101+
obj_attr_accessor attr(object &&key) const;
102+
/// See above (the only difference is that the key is provided as a string literal)
99103
str_attr_accessor attr(const char *key) const;
100104

101105
/** \rst
@@ -684,7 +688,7 @@ class accessor : public object_api<accessor<Policy>> {
684688
}
685689
template <typename T>
686690
void operator=(T &&value) & {
687-
get_cache() = reinterpret_borrow<object>(object_or_cast(std::forward<T>(value)));
691+
get_cache() = ensure_object(object_or_cast(std::forward<T>(value)));
688692
}
689693

690694
template <typename T = Policy>
@@ -712,6 +716,9 @@ class accessor : public object_api<accessor<Policy>> {
712716
}
713717

714718
private:
719+
static object ensure_object(object &&o) { return std::move(o); }
720+
static object ensure_object(handle h) { return reinterpret_borrow<object>(h); }
721+
715722
object &get_cache() const {
716723
if (!cache) {
717724
cache = Policy::get(obj, key);
@@ -1711,7 +1718,10 @@ class tuple : public object {
17111718
size_t size() const { return (size_t) PyTuple_Size(m_ptr); }
17121719
bool empty() const { return size() == 0; }
17131720
detail::tuple_accessor operator[](size_t index) const { return {*this, index}; }
1714-
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
1721+
template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
1722+
detail::item_accessor operator[](T &&o) const {
1723+
return object::operator[](std::forward<T>(o));
1724+
}
17151725
detail::tuple_iterator begin() const { return {*this, 0}; }
17161726
detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; }
17171727
};
@@ -1771,7 +1781,10 @@ class sequence : public object {
17711781
}
17721782
bool empty() const { return size() == 0; }
17731783
detail::sequence_accessor operator[](size_t index) const { return {*this, index}; }
1774-
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
1784+
template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
1785+
detail::item_accessor operator[](T &&o) const {
1786+
return object::operator[](std::forward<T>(o));
1787+
}
17751788
detail::sequence_iterator begin() const { return {*this, 0}; }
17761789
detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; }
17771790
};
@@ -1790,7 +1803,10 @@ class list : public object {
17901803
size_t size() const { return (size_t) PyList_Size(m_ptr); }
17911804
bool empty() const { return size() == 0; }
17921805
detail::list_accessor operator[](size_t index) const { return {*this, index}; }
1793-
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
1806+
template <typename T, detail::enable_if_t<detail::is_pyobject<T>::value, int> = 0>
1807+
detail::item_accessor operator[](T &&o) const {
1808+
return object::operator[](std::forward<T>(o));
1809+
}
17941810
detail::list_iterator begin() const { return {*this, 0}; }
17951811
detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; }
17961812
template <typename T>
@@ -2090,6 +2106,10 @@ item_accessor object_api<D>::operator[](handle key) const {
20902106
return {derived(), reinterpret_borrow<object>(key)};
20912107
}
20922108
template <typename D>
2109+
item_accessor object_api<D>::operator[](object &&key) const {
2110+
return {derived(), std::move(key)};
2111+
}
2112+
template <typename D>
20932113
item_accessor object_api<D>::operator[](const char *key) const {
20942114
return {derived(), pybind11::str(key)};
20952115
}
@@ -2098,6 +2118,10 @@ obj_attr_accessor object_api<D>::attr(handle key) const {
20982118
return {derived(), reinterpret_borrow<object>(key)};
20992119
}
21002120
template <typename D>
2121+
obj_attr_accessor object_api<D>::attr(object &&key) const {
2122+
return {derived(), std::move(key)};
2123+
}
2124+
template <typename D>
21012125
str_attr_accessor object_api<D>::attr(const char *key) const {
21022126
return {derived(), key};
21032127
}

include/pybind11/stl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ struct map_caster {
128128
if (!key || !value) {
129129
return handle();
130130
}
131-
d[key] = value;
131+
d[std::move(key)] = std::move(value);
132132
}
133133
return d.release();
134134
}

tests/test_pytypes.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,4 +661,38 @@ TEST_SUBMODULE(pytypes, m) {
661661
double v = x.get_value();
662662
return v * v;
663663
});
664+
665+
m.def("tuple_rvalue_getter", [](const py::tuple &tup) {
666+
// tests accessing tuple object with rvalue int
667+
for (size_t i = 0; i < tup.size(); i++) {
668+
auto o = py::handle(tup[py::int_(i)]);
669+
if (!o) {
670+
throw py::value_error("tuple is malformed");
671+
}
672+
}
673+
return tup;
674+
});
675+
m.def("list_rvalue_getter", [](const py::list &l) {
676+
// tests accessing list with rvalue int
677+
for (size_t i = 0; i < l.size(); i++) {
678+
auto o = py::handle(l[py::int_(i)]);
679+
if (!o) {
680+
throw py::value_error("list is malformed");
681+
}
682+
}
683+
return l;
684+
});
685+
m.def("populate_dict_rvalue", [](int population) {
686+
auto d = py::dict();
687+
for (int i = 0; i < population; i++) {
688+
d[py::int_(i)] = py::int_(i);
689+
}
690+
return d;
691+
});
692+
m.def("populate_obj_str_attrs", [](py::object &o, int population) {
693+
for (int i = 0; i < population; i++) {
694+
o.attr(py::str(py::int_(i))) = py::str(py::int_(i));
695+
}
696+
return o;
697+
});
664698
}

tests/test_pytypes.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import contextlib
22
import sys
3+
import types
34

45
import pytest
56

@@ -320,8 +321,7 @@ def func(self, x, *args):
320321
def test_accessor_moves():
321322
inc_refs = m.accessor_moves()
322323
if inc_refs:
323-
# To be changed in PR #3970: [1, 0, 1, 0, ...]
324-
assert inc_refs == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
324+
assert inc_refs == [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
325325
else:
326326
pytest.skip("Not defined: PYBIND11_HANDLE_REF_DEBUG")
327327

@@ -707,3 +707,30 @@ def test_implementation_details():
707707
def test_external_float_():
708708
r1 = m.square_float_(2.0)
709709
assert r1 == 4.0
710+
711+
712+
def test_tuple_rvalue_getter():
713+
pop = 1000
714+
tup = tuple(range(pop))
715+
m.tuple_rvalue_getter(tup)
716+
717+
718+
def test_list_rvalue_getter():
719+
pop = 1000
720+
my_list = list(range(pop))
721+
m.list_rvalue_getter(my_list)
722+
723+
724+
def test_populate_dict_rvalue():
725+
pop = 1000
726+
my_dict = {i: i for i in range(pop)}
727+
assert m.populate_dict_rvalue(pop) == my_dict
728+
729+
730+
def test_populate_obj_str_attrs():
731+
pop = 1000
732+
o = types.SimpleNamespace(**{str(i): i for i in range(pop)})
733+
new_o = m.populate_obj_str_attrs(o, pop)
734+
new_attrs = {k: v for k, v in new_o.__dict__.items() if not k.startswith("_")}
735+
assert all(isinstance(v, str) for v in new_attrs.values())
736+
assert len(new_attrs) == pop

0 commit comments

Comments
 (0)