Skip to content

Commit 333e889

Browse files
author
Wenzel Jakob
committed
Improved STL support, support for std::set
1 parent 723bc65 commit 333e889

File tree

6 files changed

+143
-20
lines changed

6 files changed

+143
-20
lines changed

example/example2.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,29 @@ class Example2 {
2727
return dict;
2828
}
2929

30+
/* Create and return a Python set */
31+
py::set get_set() {
32+
py::set set;
33+
set.insert(py::str("key1"));
34+
set.insert(py::str("key2"));
35+
return set;
36+
}
37+
3038
/* Create and return a C++ dictionary */
3139
std::map<std::string, std::string> get_dict_2() {
3240
std::map<std::string, std::string> result;
3341
result["key"] = "value";
3442
return result;
3543
}
3644

45+
/* Create and return a C++ set */
46+
std::set<std::string> get_set_2() {
47+
std::set<std::string> result;
48+
result.insert("key1");
49+
result.insert("key2");
50+
return result;
51+
}
52+
3753
/* Create, manipulate, and return a Python list */
3854
py::list get_list() {
3955
py::list list;
@@ -62,6 +78,18 @@ class Example2 {
6278
std::cout << "key: " << item.first << ", value=" << item.second << std::endl;
6379
}
6480

81+
/* Easily iterate over a setionary using a C++11 range-based for loop */
82+
void print_set(py::set set) {
83+
for (auto item : set)
84+
std::cout << "key: " << item << std::endl;
85+
}
86+
87+
/* STL data types are automatically casted from Python */
88+
void print_set_2(const std::set<std::string> &set) {
89+
for (auto item : set)
90+
std::cout << "key: " << item << std::endl;
91+
}
92+
6593
/* Easily iterate over a list using a C++11 range-based for loop */
6694
void print_list(py::list list) {
6795
int index = 0;
@@ -105,8 +133,12 @@ void init_ex2(py::module &m) {
105133
.def("get_dict_2", &Example2::get_dict_2, "Return a C++ dictionary")
106134
.def("get_list", &Example2::get_list, "Return a Python list")
107135
.def("get_list_2", &Example2::get_list_2, "Return a C++ list")
136+
.def("get_set", &Example2::get_set, "Return a Python set")
137+
.def("get_set2", &Example2::get_set, "Return a C++ set")
108138
.def("print_dict", &Example2::print_dict, "Print entries of a Python dictionary")
109139
.def("print_dict_2", &Example2::print_dict_2, "Print entries of a C++ dictionary")
140+
.def("print_set", &Example2::print_set, "Print entries of a Python set")
141+
.def("print_set_2", &Example2::print_set_2, "Print entries of a C++ set")
110142
.def("print_list", &Example2::print_list, "Print entries of a Python list")
111143
.def("print_list_2", &Example2::print_list_2, "Print entries of a C++ list")
112144
.def("pair_passthrough", &Example2::pair_passthrough, "Return a pair in reversed order")

example/example2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@
2929
dict_result['key2'] = 'value2'
3030
instance.print_dict_2(dict_result)
3131

32+
set_result = instance.get_set()
33+
set_result.add(u'key3')
34+
instance.print_set(set_result)
35+
36+
set_result = instance.get_set2()
37+
set_result.add(u'key3')
38+
instance.print_set_2(set_result)
39+
3240
list_result = instance.get_list()
3341
list_result.append('value2')
3442
instance.print_list(list_result)

example/example2.ref

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ key: key2, value=value2
66
key: key, value=value
77
key: key, value=value
88
key: key2, value=value2
9+
key: key3
10+
key: key2
11+
key: key1
12+
key: key1
13+
key: key2
14+
key: key3
915
Entry at positon 0: value
1016
list item 0: overwritten
1117
list item 1: value2
@@ -44,6 +50,16 @@ class EExxaammppllee22(__builtin__.object)
4450
|
4551
| Return a C++ list
4652
|
53+
| ggeett__sseett(...)
54+
| Signature : (Example2) -> set
55+
|
56+
| Return a Python set
57+
|
58+
| ggeett__sseett22(...)
59+
| Signature : (Example2) -> set
60+
|
61+
| Return a C++ set
62+
|
4763
| ppaaiirr__ppaasssstthhrroouugghh(...)
4864
| Signature : (Example2, (bool, str)) -> (str, bool)
4965
|
@@ -69,6 +85,16 @@ class EExxaammppllee22(__builtin__.object)
6985
|
7086
| Print entries of a C++ list
7187
|
88+
| pprriinntt__sseett(...)
89+
| Signature : (Example2, set) -> None
90+
|
91+
| Print entries of a Python set
92+
|
93+
| pprriinntt__sseett__22(...)
94+
| Signature : (Example2, set<str>) -> None
95+
|
96+
| Print entries of a C++ set
97+
|
7298
| tthhrrooww__eexxcceeppttiioonn(...)
7399
| Signature : (Example2) -> None
74100
|
@@ -85,7 +111,7 @@ class EExxaammppllee22(__builtin__.object)
85111
| ____nneeww____ = <built-in method __new__ of Example2_meta object>
86112
| T.__new__(S, ...) -> a new object with type S, a subtype of T
87113
|
88-
| ____ppyybbiinndd____ = <capsule object NULL>
114+
| ____ppyybbiinndd1111____ = <capsule object NULL>
89115
|
90116
| nneeww__iinnssttaannccee = <built-in method new_instance of PyCapsule object>
91117
| Signature : () -> Example2

include/pybind11/cast.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ PYBIND11_TYPE_CASTER_PYTYPE(capsule) PYBIND11_TYPE_CASTER_PYTYPE(dict)
571571
PYBIND11_TYPE_CASTER_PYTYPE(float_) PYBIND11_TYPE_CASTER_PYTYPE(int_)
572572
PYBIND11_TYPE_CASTER_PYTYPE(list) PYBIND11_TYPE_CASTER_PYTYPE(slice)
573573
PYBIND11_TYPE_CASTER_PYTYPE(tuple) PYBIND11_TYPE_CASTER_PYTYPE(function)
574+
PYBIND11_TYPE_CASTER_PYTYPE(set) PYBIND11_TYPE_CASTER_PYTYPE(iterator)
574575

575576
NAMESPACE_END(detail)
576577

include/pybind11/pytypes.h

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class object;
1919
class str;
2020
class object;
2121
class dict;
22+
class iterator;
2223
namespace detail { class accessor; }
2324

2425
/// Holds a reference to a Python object (no reference counting)
@@ -33,6 +34,8 @@ class handle {
3334
void dec_ref() const { Py_XDECREF(m_ptr); }
3435
int ref_count() const { return (int) Py_REFCNT(m_ptr); }
3536
handle get_type() { return (PyObject *) Py_TYPE(m_ptr); }
37+
inline iterator begin();
38+
inline iterator end();
3639
inline detail::accessor operator[](handle key);
3740
inline detail::accessor operator[](const char *key);
3841
inline detail::accessor attr(handle key);
@@ -73,6 +76,23 @@ class object : public handle {
7376
}
7477
};
7578

79+
class iterator : public object {
80+
public:
81+
iterator(PyObject *obj, bool borrowed = false) : object(obj, borrowed) { ++*this; }
82+
iterator& operator++() {
83+
if (ptr())
84+
value = object(PyIter_Next(ptr()), false);
85+
return *this;
86+
}
87+
bool operator==(const iterator &it) const { return *it == **this; }
88+
bool operator!=(const iterator &it) const { return *it != **this; }
89+
object operator*() { return value; }
90+
const object &operator*() const { return value; }
91+
bool check() const { return PyIter_Check(ptr()); }
92+
private:
93+
object value;
94+
};
95+
7696
NAMESPACE_BEGIN(detail)
7797
class accessor {
7898
public:
@@ -159,18 +179,6 @@ struct tuple_accessor {
159179
size_t index;
160180
};
161181

162-
class list_iterator {
163-
public:
164-
list_iterator(PyObject *list, ssize_t pos) : list(list), pos(pos) { }
165-
list_iterator& operator++() { ++pos; return *this; }
166-
object operator*() { return object(PyList_GetItem(list, pos), true); }
167-
bool operator==(const list_iterator &it) const { return it.pos == pos; }
168-
bool operator!=(const list_iterator &it) const { return it.pos != pos; }
169-
private:
170-
PyObject *list;
171-
ssize_t pos;
172-
};
173-
174182
struct dict_iterator {
175183
public:
176184
dict_iterator(PyObject *dict = nullptr, ssize_t pos = -1) : dict(dict), pos(pos) { }
@@ -194,6 +202,8 @@ inline detail::accessor handle::operator[](handle key) { return detail::accessor
194202
inline detail::accessor handle::operator[](const char *key) { return detail::accessor(ptr(), key, false); }
195203
inline detail::accessor handle::attr(handle key) { return detail::accessor(ptr(), key.ptr(), true); }
196204
inline detail::accessor handle::attr(const char *key) { return detail::accessor(ptr(), key, true); }
205+
inline iterator handle::begin() { return iterator(PyObject_GetIter(ptr())); }
206+
inline iterator handle::end() { return iterator(nullptr); }
197207

198208
#define PYBIND11_OBJECT_CVT(Name, Parent, CheckFun, CvtStmt) \
199209
Name(const handle &h, bool borrowed) : Parent(h, borrowed) { CvtStmt; } \
@@ -310,19 +320,27 @@ class dict : public object {
310320
size_t size() const { return (size_t) PyDict_Size(m_ptr); }
311321
detail::dict_iterator begin() { return (++detail::dict_iterator(ptr(), 0)); }
312322
detail::dict_iterator end() { return detail::dict_iterator(); }
323+
void clear() { PyDict_Clear(ptr()); }
313324
};
314325

315326
class list : public object {
316327
public:
317328
PYBIND11_OBJECT(list, object, PyList_Check)
318329
list(size_t size = 0) : object(PyList_New((ssize_t) size), false) { }
319330
size_t size() const { return (size_t) PyList_Size(m_ptr); }
320-
detail::list_iterator begin() { return detail::list_iterator(ptr(), 0); }
321-
detail::list_iterator end() { return detail::list_iterator(ptr(), (ssize_t) size()); }
322331
detail::list_accessor operator[](size_t index) { return detail::list_accessor(ptr(), index); }
323332
void append(const object &object) { PyList_Append(m_ptr, (PyObject *) object.ptr()); }
324333
};
325334

335+
class set : public object {
336+
public:
337+
PYBIND11_OBJECT(set, object, PySet_Check)
338+
set() : object(PySet_New(nullptr), false) { }
339+
size_t size() const { return (size_t) PySet_Size(m_ptr); }
340+
void insert(const object &object) { PySet_Add(m_ptr, (PyObject *) object.ptr()); }
341+
void clear() { PySet_Clear(ptr()); }
342+
};
343+
326344
class function : public object {
327345
public:
328346
PYBIND11_OBJECT_DEFAULT(function, object, PyFunction_Check)

include/pybind11/stl.h

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "pybind11.h"
1313
#include <map>
14+
#include <set>
1415
#include <iostream>
1516

1617

@@ -22,8 +23,8 @@
2223
NAMESPACE_BEGIN(pybind11)
2324
NAMESPACE_BEGIN(detail)
2425

25-
template <typename Value> struct type_caster<std::vector<Value>> {
26-
typedef std::vector<Value> type;
26+
template <typename Value, typename Alloc> struct type_caster<std::vector<Value, Alloc>> {
27+
typedef std::vector<Value, Alloc> type;
2728
typedef type_caster<Value> value_conv;
2829
public:
2930
bool load(PyObject *src, bool convert) {
@@ -32,8 +33,8 @@ template <typename Value> struct type_caster<std::vector<Value>> {
3233
size_t size = (size_t) PyList_GET_SIZE(src);
3334
value.reserve(size);
3435
value.clear();
36+
value_conv conv;
3537
for (size_t i=0; i<size; ++i) {
36-
value_conv conv;
3738
if (!conv.load(PyList_GetItem(src, (ssize_t) i), convert))
3839
return false;
3940
value.push_back((Value) conv);
@@ -57,9 +58,46 @@ template <typename Value> struct type_caster<std::vector<Value>> {
5758
PYBIND11_TYPE_CASTER(type, detail::descr("list<") + value_conv::name() + detail::descr(">"));
5859
};
5960

60-
template <typename Key, typename Value> struct type_caster<std::map<Key, Value>> {
61+
template <typename Value, typename Compare, typename Alloc> struct type_caster<std::set<Value, Compare, Alloc>> {
62+
typedef std::set<Value, Compare, Alloc> type;
63+
typedef type_caster<Value> value_conv;
64+
public:
65+
bool load(PyObject *src, bool convert) {
66+
pybind11::set s(src, true);
67+
if (!s.check())
68+
return false;
69+
value.clear();
70+
value_conv conv;
71+
for (const object &o: s) {
72+
if (!conv.load((PyObject *) o.ptr(), convert))
73+
return false;
74+
value.insert((Value) conv);
75+
}
76+
return true;
77+
}
78+
79+
static PyObject *cast(const type &src, return_value_policy policy, PyObject *parent) {
80+
PyObject *set = PySet_New(nullptr);
81+
for (auto const &value: src) {
82+
PyObject *value_ = value_conv::cast(value, policy, parent);
83+
if (!value_) {
84+
Py_DECREF(set);
85+
return nullptr;
86+
}
87+
if (PySet_Add(set, value) != 0) {
88+
Py_DECREF(value);
89+
Py_DECREF(set);
90+
return nullptr;
91+
}
92+
}
93+
return set;
94+
}
95+
PYBIND11_TYPE_CASTER(type, detail::descr("set<") + value_conv::name() + detail::descr(">"));
96+
};
97+
98+
template <typename Key, typename Value, typename Compare, typename Alloc> struct type_caster<std::map<Key, Value, Compare, Alloc>> {
6199
public:
62-
typedef std::map<Key, Value> type;
100+
typedef std::map<Key, Value, Compare, Alloc> type;
63101
typedef type_caster<Key> key_conv;
64102
typedef type_caster<Value> value_conv;
65103

0 commit comments

Comments
 (0)