Skip to content

Commit a9605dc

Browse files
committed
Make iterators covariant in element type
The internal Baseiter type underlies most of the ndarray iterators, and it used `*mut A` for element type A. Update it to use `NonNull<A>` which behaves identically except it's guaranteed to be non-null and is covariant w.r.t the parameter A. Add compile test from the issue. Fixes #1290
1 parent 84fe611 commit a9605dc

File tree

5 files changed

+53
-21
lines changed

5 files changed

+53
-21
lines changed

src/impl_owned_array.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use alloc::vec::Vec;
33
use std::mem;
44
use std::mem::MaybeUninit;
55

6-
#[allow(unused_imports)]
6+
#[allow(unused_imports)] // Needed for Rust 1.64
77
use rawpointer::PointerExt;
88

99
use crate::imp_prelude::*;
@@ -907,7 +907,7 @@ where D: Dimension
907907

908908
// iter is a raw pointer iterator traversing the array in memory order now with the
909909
// sorted axes.
910-
let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides);
910+
let mut iter = Baseiter::new(self_.ptr, self_.dim, self_.strides);
911911
let mut dropped_elements = 0;
912912

913913
let mut last_ptr = data_ptr;

src/impl_views/conversions.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ where D: Dimension
199199
#[inline]
200200
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
201201
{
202-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
202+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
203203
}
204204
}
205205

@@ -209,7 +209,7 @@ where D: Dimension
209209
#[inline]
210210
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
211211
{
212-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
212+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
213213
}
214214
}
215215

@@ -220,7 +220,7 @@ where D: Dimension
220220
#[inline]
221221
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
222222
{
223-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
223+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
224224
}
225225

226226
#[inline]
@@ -262,7 +262,7 @@ where D: Dimension
262262
#[inline]
263263
pub(crate) fn into_base_iter(self) -> Baseiter<A, D>
264264
{
265-
unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) }
265+
unsafe { Baseiter::new(self.ptr, self.dim, self.strides) }
266266
}
267267

268268
#[inline]

src/iterators/into_iter.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,15 @@ impl<A, D> IntoIter<A, D>
3333
where D: Dimension
3434
{
3535
/// Create a new by-value iterator that consumes `array`
36-
pub(crate) fn new(mut array: Array<A, D>) -> Self
36+
pub(crate) fn new(array: Array<A, D>) -> Self
3737
{
3838
unsafe {
3939
let array_head_ptr = array.ptr;
40-
let ptr = array.as_mut_ptr();
4140
let mut array_data = array.data;
4241
let data_len = array_data.release_all_elements();
4342
debug_assert!(data_len >= array.dim.size());
4443
let has_unreachable_elements = array.dim.size() != data_len;
45-
let inner = Baseiter::new(ptr, array.dim, array.strides);
44+
let inner = Baseiter::new(array_head_ptr, array.dim, array.strides);
4645

4746
IntoIter {
4847
array_data,

src/iterators/mod.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ use alloc::vec::Vec;
1919
use std::iter::FromIterator;
2020
use std::marker::PhantomData;
2121
use std::ptr;
22+
use std::ptr::NonNull;
23+
24+
#[allow(unused_imports)] // Needed for Rust 1.64
25+
use rawpointer::PointerExt;
2226

2327
use crate::Ix1;
2428

@@ -38,7 +42,7 @@ use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};
3842
#[derive(Debug)]
3943
pub struct Baseiter<A, D>
4044
{
41-
ptr: *mut A,
45+
ptr: NonNull<A>,
4246
dim: D,
4347
strides: D,
4448
index: Option<D>,
@@ -50,7 +54,7 @@ impl<A, D: Dimension> Baseiter<A, D>
5054
/// to be correct to avoid performing an unsafe pointer offset while
5155
/// iterating.
5256
#[inline]
53-
pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter<A, D>
57+
pub unsafe fn new(ptr: NonNull<A>, len: D, stride: D) -> Baseiter<A, D>
5458
{
5559
Baseiter {
5660
ptr,
@@ -74,7 +78,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
7478
};
7579
let offset = D::stride_offset(&index, &self.strides);
7680
self.index = self.dim.next_for(index);
77-
unsafe { Some(self.ptr.offset(offset)) }
81+
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
7882
}
7983

8084
fn size_hint(&self) -> (usize, Option<usize>)
@@ -99,7 +103,7 @@ impl<A, D: Dimension> Iterator for Baseiter<A, D>
99103
let mut i = 0;
100104
let i_end = len - elem_index;
101105
while i < i_end {
102-
accum = g(accum, row_ptr.offset(i as isize * stride));
106+
accum = g(accum, row_ptr.offset(i as isize * stride).as_ptr());
103107
i += 1;
104108
}
105109
}
@@ -140,12 +144,12 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
140144
Some(ix) => ix,
141145
};
142146
self.dim[0] -= 1;
143-
let offset = <_>::stride_offset(&self.dim, &self.strides);
147+
let offset = Ix1::stride_offset(&self.dim, &self.strides);
144148
if index == self.dim {
145149
self.index = None;
146150
}
147151

148-
unsafe { Some(self.ptr.offset(offset)) }
152+
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
149153
}
150154

151155
fn nth_back(&mut self, n: usize) -> Option<*mut A>
@@ -154,11 +158,11 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
154158
let len = self.dim[0] - index[0];
155159
if n < len {
156160
self.dim[0] -= n + 1;
157-
let offset = <_>::stride_offset(&self.dim, &self.strides);
161+
let offset = Ix1::stride_offset(&self.dim, &self.strides);
158162
if index == self.dim {
159163
self.index = None;
160164
}
161-
unsafe { Some(self.ptr.offset(offset)) }
165+
unsafe { Some(self.ptr.offset(offset).as_ptr()) }
162166
} else {
163167
self.index = None;
164168
None
@@ -178,7 +182,8 @@ impl<A> DoubleEndedIterator for Baseiter<A, Ix1>
178182
accum = g(
179183
accum,
180184
self.ptr
181-
.offset(Ix1::stride_offset(&self.dim, &self.strides)),
185+
.offset(Ix1::stride_offset(&self.dim, &self.strides))
186+
.as_ptr(),
182187
);
183188
}
184189
}

tests/iterators.rs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
#![allow(
2-
clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names
3-
)]
1+
#![allow(clippy::deref_addrof, clippy::unreadable_literal)]
42

53
use ndarray::prelude::*;
64
use ndarray::{arr3, indices, s, Slice, Zip};
@@ -1055,3 +1053,33 @@ impl Drop for DropCount<'_>
10551053
self.drops.set(self.drops.get() + 1);
10561054
}
10571055
}
1056+
1057+
#[test]
1058+
fn test_impl_iter_compiles()
1059+
{
1060+
// Requires that the iterators are covariant in the element type
1061+
1062+
// base case: std
1063+
fn slice_iter_non_empty_indices<'s, 'a>(array: &'a Vec<&'s str>) -> impl Iterator<Item = usize> + 'a
1064+
{
1065+
array
1066+
.iter()
1067+
.enumerate()
1068+
.filter(|(_index, elem)| !elem.is_empty())
1069+
.map(|(index, _elem)| index)
1070+
}
1071+
1072+
let _ = slice_iter_non_empty_indices;
1073+
1074+
// ndarray case
1075+
fn array_iter_non_empty_indices<'s, 'a>(array: &'a Array<&'s str, Ix1>) -> impl Iterator<Item = usize> + 'a
1076+
{
1077+
array
1078+
.iter()
1079+
.enumerate()
1080+
.filter(|(_index, elem)| !elem.is_empty())
1081+
.map(|(index, _elem)| index)
1082+
}
1083+
1084+
let _ = array_iter_non_empty_indices;
1085+
}

0 commit comments

Comments
 (0)