Skip to content

Fix undefined memoryview format #2223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jul 15, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions include/pybind11/buffer_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct buffer_info {
explicit buffer_info(Py_buffer *view, bool ownview = true)
: buffer_info(view->buf, view->itemsize, view->format, view->ndim,
{view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}, view->readonly) {
this->view = view;
this->m_view = view;
this->ownview = ownview;
}

Expand All @@ -73,24 +73,26 @@ struct buffer_info {
ndim = rhs.ndim;
shape = std::move(rhs.shape);
strides = std::move(rhs.strides);
std::swap(view, rhs.view);
std::swap(m_view, rhs.m_view);
std::swap(ownview, rhs.ownview);
readonly = rhs.readonly;
return *this;
}

~buffer_info() {
if (view && ownview) { PyBuffer_Release(view); delete view; }
if (m_view && ownview) { PyBuffer_Release(m_view); delete m_view; }
}

Py_buffer *view() const { return m_view; }
Py_buffer *&view() { return m_view; }
private:
struct private_ctr_tag { };

buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in, bool readonly)
: buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in), readonly) { }

Py_buffer *view = nullptr;
Py_buffer *m_view = nullptr;
bool ownview = false;
};

Expand Down
82 changes: 55 additions & 27 deletions include/pybind11/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "buffer_info.h"
#include <utility>
#include <type_traits>
#include <algorithm>

NAMESPACE_BEGIN(PYBIND11_NAMESPACE)

Expand Down Expand Up @@ -1331,35 +1332,62 @@ class buffer : public object {

class memoryview : public object {
public:
explicit memoryview(const buffer_info& info) {
static Py_buffer buf { };
// Py_buffer uses signed sizes, strides and shape!..
static std::vector<Py_ssize_t> py_strides { };
static std::vector<Py_ssize_t> py_shape { };
buf.buf = info.ptr;
buf.itemsize = info.itemsize;
buf.format = const_cast<char *>(info.format.c_str());
buf.ndim = (int) info.ndim;
buf.len = info.size;
py_strides.clear();
py_shape.clear();
for (size_t i = 0; i < (size_t) info.ndim; ++i) {
py_strides.push_back(info.strides[i]);
py_shape.push_back(info.shape[i]);
}
buf.strides = py_strides.data();
buf.shape = py_shape.data();
buf.suboffsets = nullptr;
buf.readonly = info.readonly;
buf.internal = nullptr;

m_ptr = PyMemoryView_FromBuffer(&buf);
if (!m_ptr)
pybind11_fail("Unable to create memoryview from buffer descriptor");
}

PYBIND11_OBJECT_CVT(memoryview, object, PyMemoryView_Check, PyMemoryView_FromObject)
#if PY_MAJOR_VERSION >= 3
explicit memoryview(char *mem, ssize_t size, bool writable = false)
: object(PyMemoryView_FromMemory(mem, size, (writable) ? PyBUF_WRITE : PyBUF_READ), stolen_t{}) {
if (!m_ptr) pybind11_fail("Could not allocate memoryview object!");
}
#endif
explicit memoryview(const buffer_info& info);
};

inline memoryview::memoryview(const buffer_info& info) {
// TODO: two-letter formats are not supported.
static const char* formats[] = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Observations, I have to defer to @wjakob judgement for what to do (if anything):

Here the compiler has to work very hard just to put together the array of char pointers:
static const char* formats[] = {"?", "b", "B", "h", "H", "i", "I", "q", "Q", "f", "d", "g"};
That's basically the same as the hard-coded string in detail/common.h:
static constexpr const char c = "?bBhHiIqQfdg"[...];

It's only a subset of the format strings accepted by the current Python implementation:

https://github.com/python/cpython/blob/e51dd9dad6590bf3a940723fbbaaf4f64a3c9228/Objects/memoryobject.c#L1150

Very unfortunately, this authoritative Python function is hidden away (static).
My own choice would be to copy the function into pybind11 (probably with modifications), with a long comment to explain why. I'd also carefully check every released Python version to see if the list evolved over time, to track that in our copy, if necessary.

Alternatively, if we decide we only want the subset of format strings, I'd do a slight refactor of detail/common.h to make the list available without the very involved detour through the C++ template engine.

pybind11::format_descriptor<bool>::value,
pybind11::format_descriptor<int8_t>::value,
pybind11::format_descriptor<uint8_t>::value,
pybind11::format_descriptor<int16_t>::value,
pybind11::format_descriptor<uint16_t>::value,
pybind11::format_descriptor<int32_t>::value,
pybind11::format_descriptor<uint32_t>::value,
pybind11::format_descriptor<int64_t>::value,
pybind11::format_descriptor<uint64_t>::value,
pybind11::format_descriptor<float>::value,
pybind11::format_descriptor<double>::value,
pybind11::format_descriptor<long double>::value,
};
if (info.view()) {
// Note: PyMemoryView_FromBuffer never increments obj reference.
m_ptr = (info.view()->obj) ?
PyMemoryView_FromObject(info.view()->obj) :
PyMemoryView_FromBuffer(info.view());
}
else {
size_t length = sizeof(formats) / sizeof(char*);
auto format = std::find(formats, formats + length, info.format);
if (format == (formats + length))
pybind11_fail("Invalid format string");
std::vector<Py_ssize_t> shape(info.shape.begin(), info.shape.end());
std::vector<Py_ssize_t> strides(info.strides.begin(), info.strides.end());
Py_buffer view;
view.buf = info.ptr;
view.obj = nullptr;
view.len = info.size * info.itemsize;
view.readonly = info.readonly;
view.itemsize = info.itemsize;
view.format = const_cast<char*>(*format);
view.ndim = static_cast<int>(info.ndim);
view.shape = shape.data();
view.strides = strides.data();
view.suboffsets = nullptr;
view.internal = nullptr;
m_ptr = PyMemoryView_FromBuffer(&view);
}
if (!m_ptr)
pybind11_fail("Unable to create memoryview from buffer descriptor");
}
/// @} pytypes

/// \addtogroup python_builtins
Expand Down
17 changes: 17 additions & 0 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,4 +307,21 @@ TEST_SUBMODULE(pytypes, m) {
m.def("test_list_slicing", [](py::list a) {
return a[py::slice(0, -1, 2)];
});

m.def("test_memoryview_fromobject", [](py::buffer b) {
return py::memoryview(b);
});

m.def("test_memoryview_frombuffer_reference", [](py::buffer b) {
return py::memoryview(b.request());
});

m.def("test_memoryview_frombuffer_new", []() {
const char* buf = "abc";
const char* buf2 = "\x00\x00\x00\x00";
auto mv = py::memoryview(py::buffer_info(buf, 3, 1));
// Call twice with a different buffer to check the view content.
py::memoryview(py::buffer_info(const_cast<char*>(buf2), 4, "i", 1));
return mv;
});
}
21 changes: 21 additions & 0 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,24 @@ def test_number_protocol():
def test_list_slicing():
li = list(range(100))
assert li[::2] == m.test_list_slicing(li)


@pytest.mark.parametrize('method, args, format', [
(m.test_memoryview_fromobject, (b'abc', ), 'B'),
(m.test_memoryview_frombuffer_reference, (b'abc',), 'B'),
(m.test_memoryview_frombuffer_new, tuple(), 'b'),
])
def test_memoryview(method, args, format):
view = method(*args)
assert isinstance(view, memoryview)
assert view.format == format
assert view[0] == ord(b'a') if sys.version_info[0] == 3 else 'a'
assert len(view) == 3


def test_memoryview_refcount():
buf = b'\x00\x00\x00\x00'
ref_before = sys.getrefcount(buf)
view = m.test_memoryview_frombuffer_reference(buf)
ref_after = sys.getrefcount(buf)
assert ref_before < ref_after