@@ -316,19 +316,75 @@ class dtype : public object {
316
316
}
317
317
};
318
318
319
- class array : public buffer {
319
+ NAMESPACE_BEGIN (detail)
320
+ [[noreturn]] PYBIND11_NOINLINE inline void fail_dim_check(size_t dim, size_t ndim, const std::string& msg) {
321
+ throw index_error (msg + " : " + std::to_string (dim) + " (ndim = " + std::to_string (ndim) + " )" );
322
+ }
323
+ NAMESPACE_END (detail)
324
+
325
+ class safe_access_policy {
326
+ public:
327
+ void check_axis (size_t dim, size_t ndim) const {
328
+ if (dim >= ndim) {
329
+ detail::fail_dim_check (dim, ndim, " invalid axis" );
330
+ }
331
+ }
332
+
333
+ template <typename ... Ix>
334
+ void check_indices (size_t ndim, Ix...) const
335
+ {
336
+ if (sizeof ...(Ix) > ndim) {
337
+ detail::fail_dim_check (sizeof ...(Ix), ndim, " too many indices for an array" );
338
+ }
339
+ }
340
+
341
+ template <typename ... Ix>
342
+ void check_dimensions (const size_t * shape, Ix... index ) const {
343
+ check_dimensions_impl (size_t (0 ), shape, size_t (index )...);
344
+ }
345
+
346
+ private:
347
+ void check_dimensions_impl (size_t , const size_t *) const { }
348
+
349
+ template <typename ... Ix>
350
+ void check_dimensions_impl (size_t axis, const size_t * shape, size_t i, Ix... index ) const {
351
+ if (i >= *shape) {
352
+ throw index_error (std::string (" index " ) + std::to_string (i) +
353
+ " is out of bounds for axis " + std::to_string (axis) +
354
+ " with size " + std::to_string (*shape));
355
+ }
356
+ check_dimensions_impl (axis + 1 , shape + 1 , index ...);
357
+ }
358
+ };
359
+
360
+ class unsafe_access_policy {
361
+ public:
362
+ void check_axis (size_t , size_t ) const {
363
+ }
364
+
365
+ template <typename ... Ix>
366
+ void check_indices (size_t , Ix...) const {
367
+ }
368
+
369
+ template <typename ... Ix>
370
+ void check_dimensions (const size_t *, Ix...) const {
371
+ }
372
+ };
373
+
374
+ template <class access_policy = safe_access_policy>
375
+ class array_base : public buffer , private access_policy {
320
376
public:
321
- PYBIND11_OBJECT_CVT (array , buffer, detail::npy_api::get().PyArray_Check_, raw_array)
377
+ PYBIND11_OBJECT_CVT (array_base , buffer, detail::npy_api::get().PyArray_Check_, raw_array)
322
378
323
379
enum {
324
380
c_style = detail::npy_api::NPY_C_CONTIGUOUS_,
325
381
f_style = detail::npy_api::NPY_F_CONTIGUOUS_,
326
382
forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
327
383
};
328
384
329
- array () : array (0 , static_cast <const double *>(nullptr )) {}
385
+ array_base () : array_base (0 , static_cast <const double *>(nullptr )) {}
330
386
331
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
387
+ array_base (const pybind11::dtype &dt, const std::vector<size_t > &shape,
332
388
const std::vector<size_t > &strides, const void *ptr = nullptr ,
333
389
handle base = handle()) {
334
390
auto & api = detail::npy_api::get ();
@@ -362,30 +418,30 @@ class array : public buffer {
362
418
m_ptr = tmp.release ().ptr ();
363
419
}
364
420
365
- array (const pybind11::dtype &dt, const std::vector<size_t > &shape,
421
+ array_base (const pybind11::dtype &dt, const std::vector<size_t > &shape,
366
422
const void *ptr = nullptr , handle base = handle())
367
- : array (dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
423
+ : array_base (dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
368
424
369
- array (const pybind11::dtype &dt, size_t count, const void *ptr = nullptr ,
425
+ array_base (const pybind11::dtype &dt, size_t count, const void *ptr = nullptr ,
370
426
handle base = handle())
371
- : array (dt, std::vector<size_t >{ count }, ptr, base) { }
427
+ : array_base (dt, std::vector<size_t >{ count }, ptr, base) { }
372
428
373
- template <typename T> array (const std::vector<size_t >& shape,
429
+ template <typename T> array_base (const std::vector<size_t >& shape,
374
430
const std::vector<size_t >& strides,
375
431
const T* ptr, handle base = handle())
376
- : array (pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
432
+ : array_base (pybind11::dtype::of<T>(), shape, strides, (void *) ptr, base) { }
377
433
378
434
template <typename T>
379
- array (const std::vector<size_t > &shape, const T *ptr,
435
+ array_base (const std::vector<size_t > &shape, const T *ptr,
380
436
handle base = handle())
381
- : array (shape, default_strides(shape, sizeof (T)), ptr, base) { }
437
+ : array_base (shape, default_strides(shape, sizeof (T)), ptr, base) { }
382
438
383
439
template <typename T>
384
- array (size_t count, const T *ptr, handle base = handle())
385
- : array (std::vector<size_t >{ count }, ptr, base) { }
440
+ array_base (size_t count, const T *ptr, handle base = handle())
441
+ : array_base (std::vector<size_t >{ count }, ptr, base) { }
386
442
387
- explicit array (const buffer_info &info)
388
- : array (pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
443
+ explicit array_base (const buffer_info &info)
444
+ : array_base (pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
389
445
390
446
// / Array descriptor (dtype)
391
447
pybind11::dtype dtype () const {
@@ -424,8 +480,7 @@ class array : public buffer {
424
480
425
481
// / Dimension along a given axis
426
482
size_t shape (size_t dim) const {
427
- if (dim >= ndim ())
428
- fail_dim_check (dim, " invalid axis" );
483
+ access_policy::check_axis (dim, ndim ());
429
484
return shape ()[dim];
430
485
}
431
486
@@ -436,8 +491,7 @@ class array : public buffer {
436
491
437
492
// / Stride along a given axis
438
493
size_t strides (size_t dim) const {
439
- if (dim >= ndim ())
440
- fail_dim_check (dim, " invalid axis" );
494
+ access_policy::check_axis (dim, ndim ());
441
495
return strides ()[dim];
442
496
}
443
497
@@ -473,8 +527,7 @@ class array : public buffer {
473
527
// / Byte offset from beginning of the array to a given index (full or partial).
474
528
// / May throw if the index would lead to out of bounds access.
475
529
template <typename ... Ix> size_t offset_at (Ix... index) const {
476
- if (sizeof ...(index ) > ndim ())
477
- fail_dim_check (sizeof ...(index ), " too many indices for an array" );
530
+ access_policy::check_indices (ndim (), index ...);
478
531
return byte_offset (size_t (index )...);
479
532
}
480
533
@@ -487,15 +540,15 @@ class array : public buffer {
487
540
}
488
541
489
542
// / Return a new view with all of the dimensions of length 1 removed
490
- array squeeze () {
543
+ array_base squeeze () {
491
544
auto & api = detail::npy_api::get ();
492
- return reinterpret_steal<array >(api.PyArray_Squeeze_ (m_ptr));
545
+ return reinterpret_steal<array_base >(api.PyArray_Squeeze_ (m_ptr));
493
546
}
494
547
495
548
// / Ensure that the argument is a NumPy array
496
549
// / In case of an error, nullptr is returned and the Python error is cleared.
497
- static array ensure (handle h, int ExtraFlags = 0 ) {
498
- auto result = reinterpret_steal<array >(raw_array (h.ptr (), ExtraFlags));
550
+ static array_base ensure (handle h, int ExtraFlags = 0 ) {
551
+ auto result = reinterpret_steal<array_base >(raw_array (h.ptr (), ExtraFlags));
499
552
if (!result)
500
553
PyErr_Clear ();
501
554
return result;
@@ -504,13 +557,8 @@ class array : public buffer {
504
557
protected:
505
558
template <typename , typename > friend struct detail ::npy_format_descriptor;
506
559
507
- void fail_dim_check (size_t dim, const std::string& msg) const {
508
- throw index_error (msg + " : " + std::to_string (dim) +
509
- " (ndim = " + std::to_string (ndim ()) + " )" );
510
- }
511
-
512
560
template <typename ... Ix> size_t byte_offset (Ix... index) const {
513
- check_dimensions (index ...);
561
+ access_policy:: check_dimensions (shape (), index ...);
514
562
return byte_offset_unsafe (index ...);
515
563
}
516
564
@@ -537,21 +585,6 @@ class array : public buffer {
537
585
return strides;
538
586
}
539
587
540
- template <typename ... Ix> void check_dimensions (Ix... index) const {
541
- check_dimensions_impl (size_t (0 ), shape (), size_t (index )...);
542
- }
543
-
544
- void check_dimensions_impl (size_t , const size_t *) const { }
545
-
546
- template <typename ... Ix> void check_dimensions_impl (size_t axis, const size_t * shape, size_t i, Ix... index) const {
547
- if (i >= *shape) {
548
- throw index_error (std::string (" index " ) + std::to_string (i) +
549
- " is out of bounds for axis " + std::to_string (axis) +
550
- " with size " + std::to_string (*shape));
551
- }
552
- check_dimensions_impl (axis + 1 , shape + 1 , index ...);
553
- }
554
-
555
588
// / Create array from any object -- always returns a new reference
556
589
static PyObject *raw_array (PyObject *ptr, int ExtraFlags = 0 ) {
557
590
if (ptr == nullptr )
@@ -561,35 +594,40 @@ class array : public buffer {
561
594
}
562
595
};
563
596
564
- template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
597
+ using array = array_base<safe_access_policy>;
598
+ using array_unchecked = array_base<unsafe_access_policy>;
599
+
600
+ template <typename T, int ExtraFlags = array_base<>::forcecast, class access_policy = safe_access_policy>
601
+ class array_t : public array_base <access_policy> {
565
602
public:
566
- array_t () : array(0 , static_cast <const T *>(nullptr )) {}
567
- array_t (handle h, borrowed_t ) : array(h, borrowed) { }
568
- array_t (handle h, stolen_t ) : array(h, stolen) { }
603
+ using base_type = array_base<access_policy>;
604
+ array_t () : base_type(0 , static_cast <const T *>(nullptr )) {}
605
+ array_t (handle h, borrowed_t ) : base_type(h, borrowed) { }
606
+ array_t (handle h, stolen_t ) : base_type(h, stolen) { }
569
607
570
608
PYBIND11_DEPRECATED (" Use array_t<T>::ensure() instead" )
571
- array_t (handle h, bool is_borrowed) : array (raw_array_t (h.ptr()), stolen) {
609
+ array_t (handle h, bool is_borrowed) : base_type (raw_array_t (h.ptr()), stolen) {
572
610
if (!m_ptr) PyErr_Clear ();
573
611
if (!is_borrowed) Py_XDECREF (h.ptr ());
574
612
}
575
613
576
- array_t (const object &o) : array (raw_array_t (o.ptr()), stolen) {
614
+ array_t (const object &o) : base_type (raw_array_t (o.ptr()), stolen) {
577
615
if (!m_ptr) throw error_already_set ();
578
616
}
579
617
580
- explicit array_t (const buffer_info& info) : array (info) { }
618
+ explicit array_t (const buffer_info& info) : base_type (info) { }
581
619
582
620
array_t (const std::vector<size_t > &shape,
583
621
const std::vector<size_t > &strides, const T *ptr = nullptr ,
584
622
handle base = handle())
585
- : array (shape, strides, ptr, base) { }
623
+ : base_type (shape, strides, ptr, base) { }
586
624
587
625
explicit array_t (const std::vector<size_t > &shape, const T *ptr = nullptr ,
588
626
handle base = handle())
589
- : array (shape, ptr, base) { }
627
+ : base_type (shape, ptr, base) { }
590
628
591
629
explicit array_t (size_t count, const T *ptr = nullptr , handle base = handle())
592
- : array (count, ptr, base) { }
630
+ : base_type (count, ptr, base) { }
593
631
594
632
constexpr size_t itemsize () const {
595
633
return sizeof (T);
@@ -600,25 +638,25 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public
600
638
}
601
639
602
640
template <typename ... Ix> const T* data (Ix... index) const {
603
- return static_cast <const T*>(array ::data (index ...));
641
+ return static_cast <const T*>(base_type ::data (index ...));
604
642
}
605
643
606
644
template <typename ... Ix> T* mutable_data (Ix... index) {
607
- return static_cast <T*>(array ::mutable_data (index ...));
645
+ return static_cast <T*>(base_type ::mutable_data (index ...));
608
646
}
609
647
610
648
// Reference to element at a given index
611
649
template <typename ... Ix> const T& at (Ix... index) const {
612
650
if (sizeof ...(index ) != ndim ())
613
- fail_dim_check (sizeof ...(index ), " index dimension mismatch" );
614
- return *(static_cast <const T*>(array ::data ()) + byte_offset (size_t (index )...) / itemsize ());
651
+ detail:: fail_dim_check (sizeof ...(index ), ndim ( ), " index dimension mismatch" );
652
+ return *(static_cast <const T*>(base_type ::data ()) + byte_offset (size_t (index )...) / itemsize ());
615
653
}
616
654
617
655
// Mutable reference to element at a given index
618
656
template <typename ... Ix> T& mutable_at (Ix... index) {
619
657
if (sizeof ...(index ) != ndim ())
620
- fail_dim_check (sizeof ...(index ), " index dimension mismatch" );
621
- return *(static_cast <T*>(array ::mutable_data ()) + byte_offset (size_t (index )...) / itemsize ());
658
+ detail:: fail_dim_check (sizeof ...(index ), ndim ( ), " index dimension mismatch" );
659
+ return *(static_cast <T*>(base_type ::mutable_data ()) + byte_offset (size_t (index )...) / itemsize ());
622
660
}
623
661
624
662
// / Ensure that the argument is a NumPy array of the correct dtype.
@@ -811,7 +849,7 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
811
849
812
850
// Sanity check: verify that NumPy properly parses our buffer format string
813
851
auto & api = npy_api::get ();
814
- auto arr = array (buffer_info (nullptr , itemsize, format_str, 1 ));
852
+ auto arr = array_base<> (buffer_info (nullptr , itemsize, format_str, 1 ));
815
853
if (!api.PyArray_EquivTypes_ (dtype_ptr, arr.dtype ().ptr ()))
816
854
pybind11_fail (" NumPy: invalid buffer descriptor!" );
817
855
@@ -1076,11 +1114,11 @@ struct vectorize_helper {
1076
1114
template <typename T>
1077
1115
explicit vectorize_helper (T&&f) : f(std::forward<T>(f)) { }
1078
1116
1079
- object operator ()(array_t <Args, array ::c_style | array ::forcecast>... args) {
1117
+ object operator ()(array_t <Args, array_base<> ::c_style | array_base<> ::forcecast>... args) {
1080
1118
return run (args..., make_index_sequence<sizeof ...(Args)>());
1081
1119
}
1082
1120
1083
- template <size_t ... Index> object run (array_t <Args, array ::c_style | array ::forcecast>&... args, index_sequence<Index...> index) {
1121
+ template <size_t ... Index> object run (array_t <Args, array_base<> ::c_style | array_base<> ::forcecast>&... args, index_sequence<Index...> index) {
1084
1122
/* Request buffers from all parameters */
1085
1123
const size_t N = sizeof ...(Args);
1086
1124
0 commit comments