Skip to content

Commit a4b5628

Browse files
committed
Support frozenset, tuple as dict keys
#3836 Add frozenset as a pybind11 type. Add freeze() function converting set to frozenset and list to tuple; use it in std::set and std::map casters. Add tests.
1 parent e8e229f commit a4b5628

File tree

5 files changed

+106
-15
lines changed

5 files changed

+106
-15
lines changed

include/pybind11/cast.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1617,6 +1617,25 @@ object object_api<Derived>::call(Args &&...args) const {
16171617
return operator()<policy>(std::forward<Args>(args)...);
16181618
}
16191619

1620+
// Convert list -> tuple and set -> frozenset for use as keys in dict, set etc.
1621+
// https://mail.python.org/pipermail/python-dev/2005-October/057586.html
1622+
inline object freeze(object &&obj) {
1623+
if (isinstance<list>(obj)) {
1624+
return tuple(std::move(obj));
1625+
} else if (isinstance<set>(obj)) {
1626+
return frozenset(std::move(obj));
1627+
} else {
1628+
return std::move(obj);
1629+
}
1630+
}
1631+
1632+
template <typename Caster, typename SFINAE = decltype(Caster::frozen_name)>
1633+
constexpr inline auto get_frozen_name_impl(int) { return Caster::frozen_name; }
1634+
template <typename Caster>
1635+
constexpr inline auto get_frozen_name_impl(long) { return Caster::name; }
1636+
template <typename Caster>
1637+
constexpr inline auto get_frozen_name() { return get_frozen_name_impl<Caster>(0); }
1638+
16201639
PYBIND11_NAMESPACE_END(detail)
16211640

16221641
template <typename T>

include/pybind11/pytypes.h

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1784,24 +1784,41 @@ class kwargs : public dict {
17841784
PYBIND11_OBJECT_DEFAULT(kwargs, dict, PyDict_Check)
17851785
};
17861786

1787-
class set : public object {
1787+
class set_base : public object {
1788+
protected:
1789+
PYBIND11_OBJECT(set_base, object, PyAnySet_Check)
1790+
1791+
public:
1792+
size_t size() const { return (size_t) PySet_Size(m_ptr); }
1793+
bool empty() const { return size() == 0; }
1794+
template <typename T>
1795+
bool contains(T &&val) const {
1796+
return PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 1;
1797+
}
1798+
};
1799+
1800+
class set : public set_base {
17881801
public:
1789-
PYBIND11_OBJECT_CVT(set, object, PySet_Check, PySet_New)
1790-
set() : object(PySet_New(nullptr), stolen_t{}) {
1802+
PYBIND11_OBJECT_CVT(set, set_base, PySet_Check, PySet_New)
1803+
set() : set_base(PySet_New(nullptr), stolen_t{}) {
17911804
if (!m_ptr) {
17921805
pybind11_fail("Could not allocate set object!");
17931806
}
17941807
}
1795-
size_t size() const { return (size_t) PySet_Size(m_ptr); }
1796-
bool empty() const { return size() == 0; }
17971808
template <typename T>
17981809
bool add(T &&val) /* py-non-const */ {
17991810
return PySet_Add(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 0;
18001811
}
18011812
void clear() /* py-non-const */ { PySet_Clear(m_ptr); }
1802-
template <typename T>
1803-
bool contains(T &&val) const {
1804-
return PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 1;
1813+
};
1814+
1815+
class frozenset : public set_base {
1816+
public:
1817+
PYBIND11_OBJECT_CVT(frozenset, set_base, PyFrozenSet_Check, PyFrozenSet_New)
1818+
frozenset() : set_base(PyFrozenSet_New(nullptr), stolen_t{}) {
1819+
if (!m_ptr) {
1820+
pybind11_fail("Could not allocate frozenset object!");
1821+
}
18051822
}
18061823
};
18071824

include/pybind11/stl.h

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,10 @@ struct set_caster {
5555
using key_conv = make_caster<Key>;
5656

5757
bool load(handle src, bool convert) {
58-
if (!isinstance<pybind11::set>(src)) {
58+
if (!isinstance<set_base>(src)) {
5959
return false;
6060
}
61-
auto s = reinterpret_borrow<pybind11::set>(src);
61+
auto s = reinterpret_borrow<set_base>(src);
6262
value.clear();
6363
for (auto entry : s) {
6464
key_conv conv;
@@ -79,14 +79,15 @@ struct set_caster {
7979
for (auto &&value : src) {
8080
auto value_ = reinterpret_steal<object>(
8181
key_conv::cast(forward_like<T>(value), policy, parent));
82-
if (!value_ || !s.add(std::move(value_))) {
82+
if (!value_ || !s.add(freeze(std::move(value_)))) {
8383
return handle();
8484
}
8585
}
8686
return s.release();
8787
}
8888

89-
PYBIND11_TYPE_CASTER(type, const_name("Set[") + key_conv::name + const_name("]"));
89+
PYBIND11_TYPE_CASTER(type, const_name("Set[") + get_frozen_name<key_conv>() + const_name("]"));
90+
static constexpr auto frozen_name = const_name("FrozenSet[") + get_frozen_name<key_conv>() + const_name("]");
9091
};
9192

9293
template <typename Type, typename Key, typename Value>
@@ -128,14 +129,14 @@ struct map_caster {
128129
if (!key || !value) {
129130
return handle();
130131
}
131-
d[key] = value;
132+
d[freeze(std::move(key))] = std::move(value);
132133
}
133134
return d.release();
134135
}
135136

136137
PYBIND11_TYPE_CASTER(Type,
137-
const_name("Dict[") + key_conv::name + const_name(", ") + value_conv::name
138-
+ const_name("]"));
138+
const_name("Dict[") + get_frozen_name<key_conv>() + const_name(", ")
139+
+ value_conv::name + const_name("]"));
139140
};
140141

141142
template <typename Type, typename Value>
@@ -188,6 +189,7 @@ struct list_caster {
188189
}
189190

190191
PYBIND11_TYPE_CASTER(Type, const_name("List[") + value_conv::name + const_name("]"));
192+
static constexpr auto frozen_name = const_name("Tuple[") + value_conv::name + const_name(", ...]");
191193
};
192194

193195
template <typename Type, typename Alloc>
@@ -257,6 +259,11 @@ struct array_caster {
257259
const_name("[") + const_name<Size>()
258260
+ const_name("]"))
259261
+ const_name("]"));
262+
static constexpr auto frozen_name = const_name("Tuple[") + value_conv::name
263+
+ const_name<Resizable>(const_name(", ..."),
264+
const_name("[") + const_name<Size>()
265+
+ const_name("]"))
266+
+ const_name("]");
260267
};
261268

262269
template <typename Type, size_t Size>

tests/test_stl.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,22 @@ TEST_SUBMODULE(stl, m) {
248248
return v;
249249
});
250250

251+
// test_frozen_key
252+
m.def("cast_set_map", []() {
253+
return std::map<std::set<std::string>, std::string>{{{"key1", "key2"}, "value"}};
254+
});
255+
m.def("load_set_map", [](const std::map<std::set<std::string>, std::string> &map) {
256+
return map.at({"key1", "key2"}) == "value" && map.at({"key3"}) == "value2";
257+
});
258+
m.def("cast_set_set", []() { return std::set<std::set<std::string>>{{"key1", "key2"}}; });
259+
m.def("load_set_set", [](const std::set<std::set<std::string>> &set) {
260+
return (set.count({"key1", "key2"}) != 0u) && (set.count({"key3"}) != 0u);
261+
});
262+
m.def("cast_vector_set", []() { return std::set<std::vector<int>>{{1, 2}}; });
263+
m.def("load_vector_set", [](const std::set<std::vector<int>> &set) {
264+
return (set.count({1, 2}) != 0u) && (set.count({3}) != 0u);
265+
});
266+
251267
pybind11::enum_<EnumType>(m, "EnumType")
252268
.value("kSet", EnumType::kSet)
253269
.value("kUnset", EnumType::kUnset);

tests/test_stl.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,38 @@ def test_recursive_casting():
9494
assert z[0].value == 7 and z[1].value == 42
9595

9696

97+
def test_frozen_key(doc):
98+
"""Test that we special-case C++ key types to Python immutable containers, e.g.:
99+
std::map<std::set<K>, V> <-> dict[frozenset[K], V]
100+
std::set<std::set<T>> <-> set[frozenset[T]]
101+
std::set<std::vector<T>> <-> set[tuple[T, ...]]
102+
"""
103+
s = m.cast_set_map()
104+
assert s == {frozenset({"key1", "key2"}): "value"}
105+
s[frozenset({"key3"})] = "value2"
106+
assert m.load_set_map(s)
107+
assert doc(m.cast_set_map) == "cast_set_map() -> Dict[FrozenSet[str], str]"
108+
assert (
109+
doc(m.load_set_map) == "load_set_map(arg0: Dict[FrozenSet[str], str]) -> bool"
110+
)
111+
112+
s = m.cast_set_set()
113+
assert s == {frozenset({"key1", "key2"})}
114+
s.add(frozenset({"key3"}))
115+
assert m.load_set_set(s)
116+
assert doc(m.cast_set_set) == "cast_set_set() -> Set[FrozenSet[str]]"
117+
assert doc(m.load_set_set) == "load_set_set(arg0: Set[FrozenSet[str]]) -> bool"
118+
119+
s = m.cast_vector_set()
120+
assert s == {(1, 2)}
121+
s.add((3,))
122+
assert m.load_vector_set(s)
123+
assert doc(m.cast_vector_set) == "cast_vector_set() -> Set[Tuple[int, ...]]"
124+
assert (
125+
doc(m.load_vector_set) == "load_vector_set(arg0: Set[Tuple[int, ...]]) -> bool"
126+
)
127+
128+
97129
def test_move_out_container():
98130
"""Properties use the `reference_internal` policy by default. If the underlying function
99131
returns an rvalue, the policy is automatically changed to `move` to avoid referencing

0 commit comments

Comments
 (0)