@@ -455,12 +455,15 @@ class array : public buffer {
455
455
456
456
array () : array(0 , static_cast <const double *>(nullptr )) {}
457
457
458
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
459
- const std::vector<size_t > &strides, const void *ptr = nullptr ,
460
- handle base = handle()) {
458
+ template <typename ShapeIt, typename StridesIt,
459
+ typename = detail::enable_if_t <detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
460
+ array (const pybind11::dtype &dt, ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
461
+ const void *ptr = nullptr , handle base = handle()) {
461
462
auto & api = detail::npy_api::get ();
463
+
464
+ std::vector<Py_intptr_t> shape (shape_first, shape_last), strides (strides_first, strides_last);
462
465
auto ndim = shape.size ();
463
- if (shape. size () != strides.size ())
466
+ if (ndim != strides.size ())
464
467
pybind11_fail (" NumPy: shape ndim doesn't match strides ndim" );
465
468
auto descr = dt;
466
469
@@ -475,9 +478,7 @@ class array : public buffer {
475
478
}
476
479
477
480
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_ (
478
- api.PyArray_Type_ , descr.release ().ptr (), (int ) ndim,
479
- reinterpret_cast <Py_intptr_t *>(const_cast <size_t *>(shape.data ())),
480
- reinterpret_cast <Py_intptr_t *>(const_cast <size_t *>(strides.data ())),
481
+ api.PyArray_Type_ , descr.release ().ptr (), (int ) ndim, shape.data (), strides.data (),
481
482
const_cast <void *>(ptr), flags, nullptr ));
482
483
if (!tmp)
483
484
pybind11_fail (" NumPy: unable to create array!" );
@@ -491,27 +492,59 @@ class array : public buffer {
491
492
m_ptr = tmp.release ().ptr ();
492
493
}
493
494
494
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
495
+ template <typename Shape, typename Strides,
496
+ typename = decltype(std::begin(std::declval<const Shape &>())),
497
+ typename = decltype(std::begin(std::declval<const Strides &>()))>
498
+ array (const pybind11::dtype &dt, const Shape &shape, const Strides &strides,
495
499
const void *ptr = nullptr , handle base = handle())
496
- : array(dt, shape, default_strides (shape, dt.itemsize() ), ptr, base) { }
500
+ : array(dt, std::begin( shape), std::end (shape), std::begin(strides), std::end(strides ), ptr, base) { }
497
501
498
- array (const pybind11::dtype &dt, size_t count, const void *ptr = nullptr ,
499
- handle base = handle())
500
- : array(dt, std::vector<size_t >{ count }, ptr, base) { }
502
+ template <typename TSh, typename TSt>
503
+ array (const pybind11::dtype &dt, const std::initializer_list<TSh> &shape, const std::initializer_list<TSt> &strides,
504
+ const void *ptr = nullptr , handle base = handle())
505
+ : array(dt, std::begin(shape), std::end(shape), std::begin(strides), std::end(strides), ptr, base) { }
501
506
502
- template <typename T> array (const std::vector<size_t >& shape,
503
- const std::vector<size_t >& strides,
504
- const T* ptr, handle base = handle())
505
- : array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
507
+ template <typename Shape, typename = decltype(std::begin(std::declval<const Shape &>()))>
508
+ array (const pybind11::dtype &dt, const Shape &shape, const void *ptr = nullptr , handle base = handle())
509
+ : array(dt, shape, default_strides({std::begin (shape), std::end (shape)}, dt.itemsize()), ptr, base) { }
506
510
507
- template <typename T>
508
- array (const std::vector<size_t > &shape, const T *ptr,
511
+ template <typename TSh>
512
+ array (const pybind11::dtype &dt, const std::initializer_list<TSh> &shape, const void *ptr = nullptr , handle base = handle())
513
+ : array(dt, shape, default_strides({std::begin (shape), std::end (shape)}, dt.itemsize()), ptr, base) { }
514
+
515
+ array (const pybind11::dtype &dt, size_t count, const void *ptr = nullptr ,
509
516
handle base = handle())
510
- : array(shape, default_strides(shape, sizeof (T)), ptr, base) { }
517
+ : array(dt, { count }, ptr, base) { }
518
+
519
+ template <typename T, typename ShapeIt, typename StridesIt,
520
+ typename = detail::enable_if_t <detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
521
+ array (ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
522
+ const T *ptr = nullptr , handle base = handle())
523
+ : array(pybind11::dtype::of<T>(), std::move(shape_first), std::move(shape_last),
524
+ std::move(strides_first), std::move(strides_last), ptr, base) { }
525
+
526
+ template <typename T, typename Shape, typename Strides,
527
+ typename = decltype(std::begin(std::declval<const Shape &>())),
528
+ typename = decltype(std::begin(std::declval<const Strides &>()))>
529
+ array (const Shape &shape, const Strides &strides, const T *ptr, handle base = handle())
530
+ : array(pybind11::dtype::of<T>(), shape, strides, ptr, base) { }
531
+
532
+ template <typename T, typename TSh, typename TSt>
533
+ array (const std::initializer_list<TSh> &shape, const std::initializer_list<TSt> &strides,
534
+ const T *ptr, handle base = handle())
535
+ : array(pybind11::dtype::of<T>(), shape, strides, ptr, base) { }
536
+
537
+ template <typename T, typename Shape, typename = decltype(std::begin(std::declval<const Shape &>()))>
538
+ array (const Shape &shape, const T *ptr, handle base = handle())
539
+ : array(shape, default_strides({std::begin (shape), std::end (shape)}, sizeof (T)), ptr, base) { }
540
+
541
+ template <typename T, typename TSh>
542
+ array (const std::initializer_list<TSh> &shape, const T *ptr, handle base = handle())
543
+ : array(shape, default_strides({std::begin (shape), std::end (shape)}, sizeof (T)), ptr, base) { }
511
544
512
545
template <typename T>
513
546
array (size_t count, const T *ptr, handle base = handle())
514
- : array(std::vector< size_t > { count }, ptr, base) { }
547
+ : array({ count }, ptr, base) { }
515
548
516
549
explicit array (const buffer_info &info)
517
550
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
@@ -729,13 +762,29 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
729
762
730
763
explicit array_t (const buffer_info& info) : array(info) { }
731
764
732
- array_t (const std::vector<size_t > &shape,
733
- const std::vector<size_t > &strides, const T *ptr = nullptr ,
734
- handle base = handle())
765
+ template <typename ShapeIt, typename StridesIt,
766
+ typename = detail::enable_if_t <detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
767
+ array_t (ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
768
+ const T *ptr = nullptr , handle base = handle())
769
+ : array(std::move(shape_first), std::move(shape_last), std::move(strides_first), std::move(strides_last), ptr, base) { }
770
+
771
+ template <typename Shape, typename Strides,
772
+ typename = decltype(std::begin(std::declval<const Shape &>())),
773
+ typename = decltype(std::begin(std::declval<const Strides &>()))>
774
+ array_t (const Shape &shape, const Strides &strides, const T *ptr = nullptr , handle base = handle())
775
+ : array(shape, strides, ptr, base) { }
776
+
777
+ template <typename TSh, typename TSt>
778
+ array_t (const std::initializer_list<TSh> &shape, const std::initializer_list<TSt> &strides,
779
+ const T *ptr = nullptr , handle base = handle())
735
780
: array(shape, strides, ptr, base) { }
736
781
737
- explicit array_t (const std::vector<size_t > &shape, const T *ptr = nullptr ,
738
- handle base = handle())
782
+ template <typename Shape, typename = decltype(std::begin(std::declval<const Shape &>()))>
783
+ explicit array_t (const Shape &shape, const T *ptr = nullptr , handle base = handle())
784
+ : array(shape, ptr, base) { }
785
+
786
+ template <typename TSh>
787
+ explicit array_t (const std::initializer_list<TSh> &shape, const T *ptr = nullptr , handle base = handle())
739
788
: array(shape, ptr, base) { }
740
789
741
790
explicit array_t (size_t count, const T *ptr = nullptr , handle base = handle())
0 commit comments