Skip to content

Commit 503ff2a

Browse files
ncullen93ncullen93NC CullenSkylion007pre-commit-ci[bot]
authored
view for numpy arrays (#987)
* reshape * more tests * Update numpy.h * Update test_numpy_array.py * array view * test * Update test_numpy_array.cpp * Update numpy.h * Update numpy.h * Update test_numpy_array.cpp * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix merge bug * Make clang-tidy happy * Add xfail for PyPy * Fix casting issue * Fix formatting * Apply clang-tidy * Address reviews on additional tests * Fix ordering * Do a little more reordering * Fix typo * Try improving tests * Fix error in reshape * Add one more reshape test * Fix bugs and add test * Relax test * streamlining new tests; removing a few stray msg * Fix style revert * Fix clang-tidy * Misc tweaks: * Comment: matching style in file (///), responsibility sentence, consistent punctuation. * Replacing `unsigned char` with `uint8_t` for max consistency. * Removing `1` from `array_view1` because there is only one. * Partial clang-format-diff. Co-authored-by: ncullen93 <[email protected]> Co-authored-by: NC Cullen <[email protected]> Co-authored-by: Aaron Gokaslan <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ralf Grosse-Kunstleve <[email protected]>
1 parent db44afa commit 503ff2a

File tree

3 files changed

+36
-0
lines changed

3 files changed

+36
-0
lines changed

include/pybind11/numpy.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ struct npy_api {
199199
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
200200
PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int);
201201
PyObject* (*PyArray_Newshape_)(PyObject*, PyArray_Dims*, int);
202+
PyObject* (*PyArray_View_)(PyObject*, PyObject*, PyObject*);
202203

203204
private:
204205
enum functions {
@@ -216,6 +217,7 @@ struct npy_api {
216217
API_PyArray_DescrNewFromType = 96,
217218
API_PyArray_Newshape = 135,
218219
API_PyArray_Squeeze = 136,
220+
API_PyArray_View = 137,
219221
API_PyArray_DescrConverter = 174,
220222
API_PyArray_EquivTypes = 182,
221223
API_PyArray_GetArrayParamsFromObject = 278,
@@ -248,6 +250,7 @@ struct npy_api {
248250
DECL_NPY_API(PyArray_DescrNewFromType);
249251
DECL_NPY_API(PyArray_Newshape);
250252
DECL_NPY_API(PyArray_Squeeze);
253+
DECL_NPY_API(PyArray_View);
251254
DECL_NPY_API(PyArray_DescrConverter);
252255
DECL_NPY_API(PyArray_EquivTypes);
253256
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
@@ -802,6 +805,21 @@ class array : public buffer {
802805
return new_array;
803806
}
804807

808+
/// Create a view of an array in a different data type.
809+
/// This function may fundamentally reinterpret the data in the array.
810+
/// It is the responsibility of the caller to ensure that this is safe.
811+
/// Only supports the `dtype` argument, the `type` argument is omitted,
812+
/// to be added as needed.
813+
array view(const std::string &dtype) {
814+
auto &api = detail::npy_api::get();
815+
auto new_view = reinterpret_steal<array>(api.PyArray_View_(
816+
m_ptr, dtype::from_args(pybind11::str(dtype)).release().ptr(), nullptr));
817+
if (!new_view) {
818+
throw error_already_set();
819+
}
820+
return new_view;
821+
}
822+
805823
/// Ensure that the argument is a NumPy array
806824
/// In case of an error, nullptr is returned and the Python error is cleared.
807825
static array ensure(handle h, int ExtraFlags = 0) {

tests/test_numpy_array.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,9 @@ TEST_SUBMODULE(numpy_array, sm) {
405405
return a;
406406
});
407407

408+
sm.def("array_view",
409+
[](py::array_t<uint8_t> a, const std::string &dtype) { return a.view(dtype); });
410+
408411
sm.def("reshape_initializer_list", [](py::array_t<int> a, size_t N, size_t M, size_t O) {
409412
return a.reshape({N, M, O});
410413
});

tests/test_numpy_array.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,21 @@ def test_array_create_and_resize():
476476
assert np.all(a == 42.0)
477477

478478

479+
def test_array_view():
480+
a = np.ones(100 * 4).astype("uint8")
481+
a_float_view = m.array_view(a, "float32")
482+
assert a_float_view.shape == (100 * 1,) # 1 / 4 bytes = 8 / 32
483+
484+
a_int16_view = m.array_view(a, "int16") # 1 / 2 bytes = 16 / 32
485+
assert a_int16_view.shape == (100 * 2,)
486+
487+
488+
def test_array_view_invalid():
489+
a = np.ones(100 * 4).astype("uint8")
490+
with pytest.raises(TypeError):
491+
m.array_view(a, "deadly_dtype")
492+
493+
479494
def test_reshape_initializer_list():
480495
a = np.arange(2 * 7 * 3) + 1
481496
x = m.reshape_initializer_list(a, 2, 7, 3)

0 commit comments

Comments
 (0)