@@ -455,12 +455,18 @@ class array : public buffer {
455
455
456
456
array () : array(0 , static_cast <const double *>(nullptr )) {}
457
457
458
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
459
- const std::vector<size_t > &strides, const void *ptr = nullptr ,
460
- handle base = handle()) {
461
- auto & api = detail::npy_api::get ();
462
- auto ndim = shape.size ();
463
- if (shape.size () != strides.size ())
458
+ using ShapeContainer = detail::any_container<Py_intptr_t>;
459
+ using StridesContainer = detail::any_container<Py_intptr_t>;
460
+
461
+ // Constructs an array taking shape/strides from arbitrary container types
462
+ array (const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
463
+ const void *ptr = nullptr , handle base = handle()) {
464
+
465
+ if (strides->empty ())
466
+ strides = default_strides (*shape, dt.itemsize ());
467
+
468
+ auto ndim = shape->size ();
469
+ if (ndim != strides->size ())
464
470
pybind11_fail (" NumPy: shape ndim doesn't match strides ndim" );
465
471
auto descr = dt;
466
472
@@ -474,10 +480,9 @@ class array : public buffer {
474
480
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
475
481
}
476
482
483
+ auto &api = detail::npy_api::get ();
477
484
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_ (
478
- api.PyArray_Type_ , descr.release ().ptr (), (int ) ndim,
479
- reinterpret_cast <Py_intptr_t *>(const_cast <size_t *>(shape.data ())),
480
- reinterpret_cast <Py_intptr_t *>(const_cast <size_t *>(strides.data ())),
485
+ api.PyArray_Type_ , descr.release ().ptr (), (int ) ndim, shape->data (), strides->data (),
481
486
const_cast <void *>(ptr), flags, nullptr ));
482
487
if (!tmp)
483
488
pybind11_fail (" NumPy: unable to create array!" );
@@ -491,27 +496,37 @@ class array : public buffer {
491
496
m_ptr = tmp.release ().ptr ();
492
497
}
493
498
494
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
495
- const void *ptr = nullptr , handle base = handle())
496
- : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
499
+ template <typename ShapeIt, typename StridesIt,
500
+ typename = detail::enable_if_t <detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
501
+ array (const pybind11::dtype &dt, ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
502
+ const void *ptr = nullptr , handle base = handle())
503
+ : array(dt, {shape_first, shape_last}, {strides_first, strides_last}, ptr, base) { }
504
+
505
+ array (const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr , handle base = handle())
506
+ : array(dt, std::move(shape), {}, ptr, base) { }
497
507
498
508
array (const pybind11::dtype &dt, size_t count, const void *ptr = nullptr ,
499
509
handle base = handle())
500
- : array(dt, std::vector< size_t >{ count }, ptr, base) { }
510
+ : array(dt, ShapeContainer{{ count } }, ptr, base) { }
501
511
502
- template <typename T> array (const std::vector<size_t >& shape,
503
- const std::vector<size_t >& strides,
504
- const T* ptr, handle base = handle())
505
- : array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
512
+ template <typename T, typename ShapeIt, typename StridesIt,
513
+ typename = detail::enable_if_t <detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
514
+ array (ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
515
+ const T *ptr = nullptr , handle base = handle())
516
+ : array(pybind11::dtype::of<T>(), ShapeContainer(std::move(shape_first), std::move(shape_last)),
517
+ StrideContainer (std::move(strides_first), std::move(strides_last)), ptr, base) { }
506
518
507
519
template <typename T>
508
- array (const std::vector<size_t > &shape, const T *ptr,
509
- handle base = handle())
510
- : array(shape, default_strides(shape, sizeof (T)), ptr, base) { }
520
+ array (ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
521
+ : array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }
522
+
523
+ template <typename T>
524
+ array (ShapeContainer shape, const T *ptr, handle base = handle())
525
+ : array(std::move(shape), {}, ptr, base) { }
511
526
512
527
template <typename T>
513
528
array (size_t count, const T *ptr, handle base = handle())
514
- : array(std::vector< size_t >{ count }, ptr, base) { }
529
+ : array({{ count } }, ptr, base) { }
515
530
516
531
explicit array (const buffer_info &info)
517
532
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
@@ -673,9 +688,9 @@ class array : public buffer {
673
688
throw std::domain_error (" array is not writeable" );
674
689
}
675
690
676
- static std::vector<size_t > default_strides (const std::vector<size_t >& shape, size_t itemsize) {
691
+ static std::vector<Py_intptr_t > default_strides (const std::vector<Py_intptr_t >& shape, size_t itemsize) {
677
692
auto ndim = shape.size ();
678
- std::vector<size_t > strides (ndim);
693
+ std::vector<Py_intptr_t > strides (ndim);
679
694
if (ndim) {
680
695
std::fill (strides.begin (), strides.end (), itemsize);
681
696
for (size_t i = 0 ; i < ndim - 1 ; i++)
@@ -729,14 +744,18 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
729
744
730
745
explicit array_t (const buffer_info& info) : array(info) { }
731
746
732
- array_t (const std::vector<size_t > &shape,
733
- const std::vector<size_t > &strides, const T *ptr = nullptr ,
734
- handle base = handle())
735
- : array(shape, strides, ptr, base) { }
747
+ array_t (ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr , handle base = handle())
748
+ : array(std::move(shape), std::move(strides), ptr, base) { }
749
+
750
+ template <typename ShapeIt, typename StridesIt,
751
+ typename = detail::enable_if_t <detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
752
+ array_t (ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
753
+ const T *ptr = nullptr , handle base = handle())
754
+ : array(ShapeContainer(std::move(shape_first), std::move(shape_last)),
755
+ StridesContainer (std::move(strides_first), std::move(strides_last)), ptr, base) { }
736
756
737
- explicit array_t (const std::vector<size_t > &shape, const T *ptr = nullptr ,
738
- handle base = handle())
739
- : array(shape, ptr, base) { }
757
+ explicit array_t (ShapeContainer shape, const T *ptr = nullptr , handle base = handle())
758
+ : array(std::move(shape), ptr, base) { }
740
759
741
760
explicit array_t (size_t count, const T *ptr = nullptr , handle base = handle())
742
761
: array(count, ptr, base) { }
0 commit comments