@@ -316,7 +316,64 @@ class dtype : public object {
316
316
}
317
317
};
318
318
319
- class array : public buffer {
319
+ namespace detail {
320
+ inline void fail_dim_check (size_t dim, size_t ndim, const std::string& msg) {
321
+ throw index_error (msg + " : " + std::to_string (dim) +
322
+ " (ndim = " + std::to_string (ndim) + " )" );
323
+ }
324
+ }
325
+
326
+ class safe_access_policy {
327
+ public:
328
+ void check_axis (size_t dim, size_t ndim) const {
329
+ if (dim >= ndim) {
330
+ detail::fail_dim_check (dim, ndim, " invalid axis" );
331
+ }
332
+ }
333
+
334
+ template <typename ... Ix>
335
+ void check_indices (size_t ndim, Ix...) const
336
+ {
337
+ if (sizeof ...(Ix) > ndim) {
338
+ detail::fail_dim_check (sizeof ...(Ix), ndim, " too many indices for an array" );
339
+ }
340
+ }
341
+
342
+ template <typename ... Ix>
343
+ void check_dimensions (const size_t * shape, Ix... index) const {
344
+ check_dimensions_impl (size_t (0 ), shape, size_t (index )...);
345
+ }
346
+
347
+ private:
348
+ void check_dimensions_impl (size_t , const size_t *) const { }
349
+
350
+ template <typename ... Ix>
351
+ void check_dimensions_impl (size_t axis, const size_t * shape, size_t i, Ix... index) const {
352
+ if (i >= *shape) {
353
+ throw index_error (std::string (" index " ) + std::to_string (i) +
354
+ " is out of bounds for axis " + std::to_string (axis) +
355
+ " with size " + std::to_string (*shape));
356
+ }
357
+ check_dimensions_impl (axis + 1 , shape + 1 , index ...);
358
+ }
359
+ };
360
+
361
+ class unsafe_access_policy {
362
+ public:
363
+ void check_axis (size_t , size_t ) const {
364
+ }
365
+
366
+ template <typename ... Ix>
367
+ void check_indices (size_t , Ix...) const {
368
+ }
369
+
370
+ template <typename ... Ix>
371
+ void check_dimensions (const size_t *, Ix...) const {
372
+ }
373
+ };
374
+
375
+ template <class access_policy = safe_access_policy>
376
+ class array : public buffer , private access_policy {
320
377
public:
321
378
PYBIND11_OBJECT_CVT (array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
322
379
@@ -424,8 +481,7 @@ class array : public buffer {
424
481
425
482
// / Dimension along a given axis
426
483
size_t shape (size_t dim) const {
427
- if (dim >= ndim ())
428
- fail_dim_check (dim, " invalid axis" );
484
+ access_policy::check_axis (dim, ndim ());
429
485
return shape ()[dim];
430
486
}
431
487
@@ -436,8 +492,7 @@ class array : public buffer {
436
492
437
493
// / Stride along a given axis
438
494
size_t strides (size_t dim) const {
439
- if (dim >= ndim ())
440
- fail_dim_check (dim, " invalid axis" );
495
+ access_policy::check_axis (dim, ndim ());
441
496
return strides ()[dim];
442
497
}
443
498
@@ -473,8 +528,7 @@ class array : public buffer {
473
528
// / Byte offset from beginning of the array to a given index (full or partial).
474
529
// / May throw if the index would lead to out of bounds access.
475
530
template <typename ... Ix> size_t offset_at (Ix... index) const {
476
- if (sizeof ...(index ) > ndim ())
477
- fail_dim_check (sizeof ...(index ), " too many indices for an array" );
531
+ access_policy::check_indices (ndim (), index ...);
478
532
return byte_offset (size_t (index )...);
479
533
}
480
534
@@ -504,13 +558,8 @@ class array : public buffer {
504
558
protected:
505
559
template <typename , typename > friend struct detail ::npy_format_descriptor;
506
560
507
- void fail_dim_check (size_t dim, const std::string& msg) const {
508
- throw index_error (msg + " : " + std::to_string (dim) +
509
- " (ndim = " + std::to_string (ndim ()) + " )" );
510
- }
511
-
512
561
template <typename ... Ix> size_t byte_offset (Ix... index) const {
513
- check_dimensions (index ...);
562
+ access_policy:: check_dimensions (shape (), index ...);
514
563
return byte_offset_unsafe (index ...);
515
564
}
516
565
@@ -537,21 +586,6 @@ class array : public buffer {
537
586
return strides;
538
587
}
539
588
540
- template <typename ... Ix> void check_dimensions (Ix... index) const {
541
- check_dimensions_impl (size_t (0 ), shape (), size_t (index )...);
542
- }
543
-
544
- void check_dimensions_impl (size_t , const size_t *) const { }
545
-
546
- template <typename ... Ix> void check_dimensions_impl (size_t axis, const size_t * shape, size_t i, Ix... index) const {
547
- if (i >= *shape) {
548
- throw index_error (std::string (" index " ) + std::to_string (i) +
549
- " is out of bounds for axis " + std::to_string (axis) +
550
- " with size " + std::to_string (*shape));
551
- }
552
- check_dimensions_impl (axis + 1 , shape + 1 , index ...);
553
- }
554
-
555
589
// / Create array from any object -- always returns a new reference
556
590
static PyObject *raw_array (PyObject *ptr, int ExtraFlags = 0 ) {
557
591
if (ptr == nullptr )
@@ -561,35 +595,36 @@ class array : public buffer {
561
595
}
562
596
};
563
597
564
- template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
598
+ template <typename T, int ExtraFlags = array<> ::forcecast> class array_t : public array <> {
565
599
public:
566
- array_t () : array(0 , static_cast <const T *>(nullptr )) {}
567
- array_t (handle h, borrowed_t ) : array(h, borrowed) { }
568
- array_t (handle h, stolen_t ) : array(h, stolen) { }
600
+ using base_type = array<>;
601
+ array_t () : base_type(0 , static_cast <const T *>(nullptr )) {}
602
+ array_t (handle h, borrowed_t ) : base_type(h, borrowed) { }
603
+ array_t (handle h, stolen_t ) : base_type(h, stolen) { }
569
604
570
605
PYBIND11_DEPRECATED (" Use array_t<T>::ensure() instead" )
571
- array_t (handle h, bool is_borrowed) : array (raw_array_t (h.ptr()), stolen) {
606
+ array_t (handle h, bool is_borrowed) : base_type (raw_array_t (h.ptr()), stolen) {
572
607
if (!m_ptr) PyErr_Clear ();
573
608
if (!is_borrowed) Py_XDECREF (h.ptr ());
574
609
}
575
610
576
- array_t (const object &o) : array (raw_array_t (o.ptr()), stolen) {
611
+ array_t (const object &o) : base_type (raw_array_t (o.ptr()), stolen) {
577
612
if (!m_ptr) throw error_already_set ();
578
613
}
579
614
580
- explicit array_t (const buffer_info& info) : array (info) { }
615
+ explicit array_t (const buffer_info& info) : base_type (info) { }
581
616
582
617
array_t (const std::vector<size_t > &shape,
583
618
const std::vector<size_t > &strides, const T *ptr = nullptr ,
584
619
handle base = handle())
585
- : array (shape, strides, ptr, base) { }
620
+ : base_type (shape, strides, ptr, base) { }
586
621
587
622
explicit array_t (const std::vector<size_t > &shape, const T *ptr = nullptr ,
588
623
handle base = handle())
589
- : array (shape, ptr, base) { }
624
+ : base_type (shape, ptr, base) { }
590
625
591
626
explicit array_t (size_t count, const T *ptr = nullptr , handle base = handle())
592
- : array (count, ptr, base) { }
627
+ : base_type (count, ptr, base) { }
593
628
594
629
constexpr size_t itemsize () const {
595
630
return sizeof (T);
@@ -600,25 +635,25 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
600
635
}
601
636
602
637
template <typename ... Ix> const T* data (Ix... index) const {
603
- return static_cast <const T*>(array ::data (index ...));
638
+ return static_cast <const T*>(base_type ::data (index ...));
604
639
}
605
640
606
641
template <typename ... Ix> T* mutable_data (Ix... index) {
607
- return static_cast <T*>(array ::mutable_data (index ...));
642
+ return static_cast <T*>(base_type ::mutable_data (index ...));
608
643
}
609
644
610
645
// Reference to element at a given index
611
646
template <typename ... Ix> const T& at (Ix... index) const {
612
647
if (sizeof ...(index ) != ndim ())
613
- fail_dim_check (sizeof ...(index ), " index dimension mismatch" );
614
- return *(static_cast <const T*>(array ::data ()) + byte_offset (size_t (index )...) / itemsize ());
648
+ detail:: fail_dim_check (sizeof ...(index ), ndim ( ), " index dimension mismatch" );
649
+ return *(static_cast <const T*>(base_type ::data ()) + byte_offset (size_t (index )...) / itemsize ());
615
650
}
616
651
617
652
// Mutable reference to element at a given index
618
653
template <typename ... Ix> T& mutable_at (Ix... index) {
619
654
if (sizeof ...(index ) != ndim ())
620
- fail_dim_check (sizeof ...(index ), " index dimension mismatch" );
621
- return *(static_cast <T*>(array ::mutable_data ()) + byte_offset (size_t (index )...) / itemsize ());
655
+ detail:: fail_dim_check (sizeof ...(index ), ndim ( ), " index dimension mismatch" );
656
+ return *(static_cast <T*>(base_type ::mutable_data ()) + byte_offset (size_t (index )...) / itemsize ());
622
657
}
623
658
624
659
// / Ensure that the argument is a NumPy array of the correct dtype.
@@ -811,7 +846,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
811
846
812
847
// Sanity check: verify that NumPy properly parses our buffer format string
813
848
auto & api = npy_api::get ();
814
- auto arr = array (buffer_info (nullptr , itemsize, format_str, 1 ));
849
+ auto arr = array<> (buffer_info (nullptr , itemsize, format_str, 1 ));
815
850
if (!api.PyArray_EquivTypes_ (dtype_ptr, arr.dtype ().ptr ()))
816
851
pybind11_fail (" NumPy: invalid buffer descriptor!" );
817
852
@@ -1076,11 +1111,11 @@ struct vectorize_helper {
1076
1111
template <typename T>
1077
1112
explicit vectorize_helper (T&&f) : f(std::forward<T>(f)) { }
1078
1113
1079
- object operator ()(array_t <Args, array::c_style | array::forcecast>... args) {
1114
+ object operator ()(array_t <Args, array<> ::c_style | array<> ::forcecast>... args) {
1080
1115
return run (args..., make_index_sequence<sizeof ...(Args)>());
1081
1116
}
1082
1117
1083
- template <size_t ... Index> object run (array_t <Args, array::c_style | array::forcecast>&... args, index_sequence<Index...> index) {
1118
+ template <size_t ... Index> object run (array_t <Args, array<> ::c_style | array<> ::forcecast>&... args, index_sequence<Index...> index) {
1084
1119
/* Request buffers from all parameters */
1085
1120
const size_t N = sizeof ...(Args);
1086
1121
0 commit comments