Skip to content

Commit b926396

Browse files
authored
bugfix: py contains raises errors when appropiate (#4209)
* bugfix: contains now throws an exception if the key is not hashable * Fix tests and improve robustness * Remove todo * Workaround PyPy corner case * PyPy xfail * Fix typo * fix xfail * Make clang-tidy happy * Remove redundant exc checking
1 parent 5b5547b commit b926396

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

include/pybind11/pytypes.h

+10-2
Original file line numberDiff line numberDiff line change
@@ -1967,7 +1967,11 @@ class dict : public object {
19671967
void clear() /* py-non-const */ { PyDict_Clear(ptr()); }
19681968
template <typename T>
19691969
bool contains(T &&key) const {
1970-
return PyDict_Contains(m_ptr, detail::object_or_cast(std::forward<T>(key)).ptr()) == 1;
1970+
auto result = PyDict_Contains(m_ptr, detail::object_or_cast(std::forward<T>(key)).ptr());
1971+
if (result == -1) {
1972+
throw error_already_set();
1973+
}
1974+
return result == 1;
19711975
}
19721976

19731977
private:
@@ -2053,7 +2057,11 @@ class anyset : public object {
20532057
bool empty() const { return size() == 0; }
20542058
template <typename T>
20552059
bool contains(T &&val) const {
2056-
return PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr()) == 1;
2060+
auto result = PySet_Contains(m_ptr, detail::object_or_cast(std::forward<T>(val)).ptr());
2061+
if (result == -1) {
2062+
throw error_already_set();
2063+
}
2064+
return result == 1;
20572065
}
20582066
};
20592067

tests/test_pytypes.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ TEST_SUBMODULE(pytypes, m) {
183183
return d2;
184184
});
185185
m.def("dict_contains",
186-
[](const py::dict &dict, py::object val) { return dict.contains(val); });
186+
[](const py::dict &dict, const py::object &val) { return dict.contains(val); });
187187
m.def("dict_contains",
188188
[](const py::dict &dict, const char *val) { return dict.contains(val); });
189189

@@ -538,6 +538,9 @@ TEST_SUBMODULE(pytypes, m) {
538538

539539
m.def("hash_function", [](py::object obj) { return py::hash(std::move(obj)); });
540540

541+
m.def("obj_contains",
542+
[](py::object &obj, const py::object &key) { return obj.contains(key); });
543+
541544
m.def("test_number_protocol", [](const py::object &a, const py::object &b) {
542545
py::list l;
543546
l.append(a.equal(b));

tests/test_pytypes.py

+25
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,31 @@ def test_dict(capture, doc):
168168
assert m.dict_keyword_constructor() == {"x": 1, "y": 2, "z": 3}
169169

170170

171+
class CustomContains:
172+
d = {"key": None}
173+
174+
def __contains__(self, m):
175+
return m in self.d
176+
177+
178+
@pytest.mark.parametrize(
179+
"arg,func",
180+
[
181+
(set(), m.anyset_contains),
182+
(dict(), m.dict_contains),
183+
(CustomContains(), m.obj_contains),
184+
],
185+
)
186+
@pytest.mark.xfail("env.PYPY and sys.pypy_version_info < (7, 3, 10)", strict=False)
187+
def test_unhashable_exceptions(arg, func):
188+
class Unhashable:
189+
__hash__ = None
190+
191+
with pytest.raises(TypeError) as exc_info:
192+
func(arg, Unhashable())
193+
assert "unhashable type:" in str(exc_info.value)
194+
195+
171196
def test_tuple():
172197
assert m.tuple_no_args() == ()
173198
assert m.tuple_ssize_t() == ()

0 commit comments

Comments
 (0)