Skip to content

Commit ad7fdb6

Browse files
committed
Improve ptr_rotate performance, tests, and benchmarks
1 parent 890881f commit ad7fdb6

File tree

3 files changed

+216
-69
lines changed

3 files changed

+216
-69
lines changed

src/libcore/benches/slice.rs

+26
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,29 @@ fn binary_search_l2_with_dups(b: &mut Bencher) {
5555
fn binary_search_l3_with_dups(b: &mut Bencher) {
5656
binary_search(b, Cache::L3, |i| i / 16 * 16);
5757
}
58+
59+
macro_rules! rotate {
60+
($fn:ident, $n:expr, $mapper:expr) => {
61+
#[bench]
62+
fn $fn(b: &mut Bencher) {
63+
let mut x = (0usize..$n).map(&$mapper).collect::<Vec<_>>();
64+
b.iter(|| {
65+
for s in 0..x.len() {
66+
x[..].rotate_right(s);
67+
}
68+
black_box(x[0].clone())
69+
})
70+
}
71+
};
72+
}
73+
74+
#[derive(Clone)]
75+
struct Rgb(u8, u8, u8);
76+
77+
rotate!(rotate_u8, 32, |i| i as u8);
78+
rotate!(rotate_rgb, 32, |i| Rgb(i as u8, (i as u8).wrapping_add(7), (i as u8).wrapping_add(42)));
79+
rotate!(rotate_usize, 32, |i| i);
80+
rotate!(rotate_16_usize_4, 16, |i| [i; 4]);
81+
rotate!(rotate_16_usize_5, 16, |i| [i; 5]);
82+
rotate!(rotate_64_usize_4, 64, |i| [i; 4]);
83+
rotate!(rotate_64_usize_5, 64, |i| [i; 5]);

src/libcore/slice/rotate.rs

+152-69
Original file line numberDiff line numberDiff line change
@@ -2,88 +2,171 @@ use crate::cmp;
22
use crate::mem::{self, MaybeUninit};
33
use crate::ptr;
44

5-
/// Rotates the range `[mid-left, mid+right)` such that the element at `mid` becomes the first
/// element. Equivalently, rotates the range `left` elements to the left or `right` elements to the
/// right.
///
/// # Safety
///
/// The specified range must be valid for reading and writing.
///
/// # Algorithm
///
/// Algorithm 1 is used for small values of `left + right` or for large `T`. The elements are moved
/// into their final positions one at a time starting at `mid - left` and advancing by `right` steps
/// modulo `left + right`, such that only one temporary is needed. Eventually, we arrive back at
/// `mid - left`. However, if `gcd(left + right, right)` is not 1, the above steps skipped over
/// elements. For example:
/// ```text
/// left = 10, right = 6
/// the `^` indicates an element in its final place
/// 6 7 8 9 10 11 12 13 14 15 . 0 1 2 3 4 5
/// after using one step of the above algorithm (The X will be overwritten at the end of the round,
/// and 12 is stored in a temporary):
/// X 7 8 9 10 11 6 13 14 15 . 0 1 2 3 4 5
///               ^
/// after using another step (now 2 is in the temporary):
/// X 7 8 9 10 11 6 13 14 15 . 0 1 12 3 4 5
///               ^                 ^
/// after the third step (the steps wrap around, and 8 is in the temporary):
/// X 7 2 9 10 11 6 13 14 15 . 0 1 12 3 4 5
///     ^         ^                 ^
/// after 7 more steps, the round ends with the temporary 0 getting put in the X:
/// 0 7 2 9 4 11 6 13 8 15 . 10 1 12 3 14 5
/// ^   ^   ^    ^    ^       ^    ^    ^
/// ```
/// Fortunately, the number of skipped over elements between finalized elements is always equal, so
/// we can just offset our starting position and do more rounds (the total number of rounds is the
/// `gcd(left + right, right)` value). The end result is that all elements are finalized once and
/// only once.
///
/// Algorithm 2 is used if `left + right` is large but `min(left, right)` is small enough to
/// fit onto a stack buffer. The `min(left, right)` elements are copied onto the buffer, `memmove`
/// is applied to the others, and the ones on the buffer are moved back into the hole on the
/// opposite side of where they originated.
///
/// Algorithms that can be vectorized outperform the above once `left + right` becomes large enough.
/// Algorithm 1 can be vectorized by chunking and performing many rounds at once, but there are too
/// few rounds on average until `left + right` is enormous, and the worst case of a single
/// round is always there. Instead, algorithm 3 utilizes repeated swapping of
/// `min(left, right)` elements until a smaller rotate problem is left.
///
/// ```text
/// left = 11, right = 4
/// [4 5 6 7 8 9 10 11 12 13 14 . 0 1 2 3]
///                  ^  ^  ^  ^   ^ ^ ^ ^ swapping the right most elements with elements to the left
/// [4 5 6 7 8 9 10 . 0 1 2 3] 11 12 13 14
///        ^ ^ ^ ^    ^ ^ ^ ^ swapping these
/// [4 5 6 . 0 1 2 3] 7 8 9 10 11 12 13 14
/// we cannot swap any more, but a smaller rotation problem is left to solve
/// ```
/// when `left < right` the swapping happens from the left instead.
pub unsafe fn ptr_rotate<T>(mut left: usize, mut mid: *mut T, mut right: usize) {
    type BufType = [usize; 32];
    if mem::size_of::<T>() == 0 {
        // Rotating zero-sized elements is a no-op.
        return;
    }
    loop {
        // N.B. the below algorithms can fail if these cases are not checked
        if (right == 0) || (left == 0) {
            return;
        }
        if (left + right < 24) || (mem::size_of::<T>() > mem::size_of::<[usize; 4]>()) {
            // Algorithm 1
            // Microbenchmarks indicate that the average performance for random shifts is better all
            // the way until about `left + right == 32`, but the worst case performance breaks even
            // around 16. 24 was chosen as middle ground. If the size of `T` is larger than 4
            // `usize`s, this algorithm also outperforms other algorithms.
            let x = mid.sub(left);
            // beginning of first round
            let mut tmp: T = x.read();
            let mut i = right;
            // `gcd` can be found before hand by calculating `gcd(left + right, right)`,
            // but it is faster to do one loop which calculates the gcd as a side effect, then
            // doing the rest of the chunk
            let mut gcd = right;
            // benchmarks reveal that it is faster to swap temporaries all the way through instead
            // of reading one temporary once, copying backwards, and then writing that temporary at
            // the very end. This is possibly due to the fact that swapping or replacing temporaries
            // uses only one memory address in the loop instead of needing to manage two.
            loop {
                tmp = x.add(i).replace(tmp);
                // instead of incrementing `i` and then checking if it is outside the bounds, we
                // check if `i` will go outside the bounds on the next increment. This prevents
                // any wrapping of pointers or `usize`.
                if i >= left {
                    i -= left;
                    if i == 0 {
                        // end of first round
                        x.write(tmp);
                        break;
                    }
                    // this conditional must be here if `left + right >= 15`
                    if i < gcd {
                        gcd = i;
                    }
                } else {
                    i += right;
                }
            }
            // finish the chunk with more rounds
            for start in 1..gcd {
                tmp = x.add(start).read();
                i = start + right;
                loop {
                    tmp = x.add(i).replace(tmp);
                    if i >= left {
                        i -= left;
                        if i == start {
                            x.add(start).write(tmp);
                            break;
                        }
                    } else {
                        i += right;
                    }
                }
            }
            return;
        // `T` is not a zero-sized type, so it's okay to divide by its size.
        } else if cmp::min(left, right) <= mem::size_of::<BufType>() / mem::size_of::<T>() {
            // Algorithm 2
            // The `[T; 0]` here is to ensure this is appropriately aligned for T
            let mut rawarray = MaybeUninit::<(BufType, [T; 0])>::uninit();
            let buf = rawarray.as_mut_ptr() as *mut T;
            let dim = mid.sub(left).add(right);
            if left <= right {
                ptr::copy_nonoverlapping(mid.sub(left), buf, left);
                ptr::copy(mid, mid.sub(left), right);
                ptr::copy_nonoverlapping(buf, dim, left);
            } else {
                ptr::copy_nonoverlapping(mid, buf, right);
                ptr::copy(mid.sub(left), dim, left);
                ptr::copy_nonoverlapping(buf, mid.sub(left), right);
            }
            return;
        } else if left >= right {
            // Algorithm 3
            // There is an alternate way of swapping that involves finding where the last swap
            // of this algorithm would be, and swapping using that last chunk instead of swapping
            // adjacent chunks like this algorithm is doing, but this way is still faster.
            loop {
                ptr::swap_nonoverlapping(mid.sub(right), mid, right);
                mid = mid.sub(right);
                left -= right;
                if left < right {
                    break;
                }
            }
        } else {
            // Algorithm 3, `left < right`
            loop {
                ptr::swap_nonoverlapping(mid.sub(left), mid, left);
                mid = mid.add(left);
                right -= left;
                if right < left {
                    break;
                }
            }
        }
    }
}

src/libcore/tests/slice.rs

+38
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,44 @@ fn test_rotate_right() {
11301130
}
11311131
}
11321132

1133+
#[test]
#[cfg(not(miri))]
fn brute_force_rotate_test_0() {
    // Exhaustively checks every (length, shift) pair up to `n`, so that edge
    // cases at the boundaries between the rotate algorithms (and their
    // interactions when one falls through to another) are all covered.
    let n = 300;
    for len in 0..n {
        for s in 0..len {
            let mut v = Vec::with_capacity(len);
            for i in 0..len {
                v.push(i);
            }
            v[..].rotate_right(s);
            for i in 0..v.len() {
                // after rotate_right(s), index i holds (i - s) mod len;
                // wrapping arithmetic avoids underflow when i < s
                assert_eq!(v[i], v.len().wrapping_add(i.wrapping_sub(s)) % v.len());
            }
        }
    }
}
1151+
1152+
#[test]
fn brute_force_rotate_test_1() {
    // `ptr_rotate` covers so many kinds of pointer usage, that this is just a good test for
    // pointers in general. This uses a `[usize; 4]` to hit all algorithms without overwhelming miri
    let n = 30;
    for len in 0..n {
        for s in 0..len {
            let mut v: Vec<[usize; 4]> = Vec::with_capacity(len);
            for i in 0..len {
                v.push([i, 0, 0, 0]);
            }
            v[..].rotate_right(s);
            for i in 0..v.len() {
                // only the first lane carries the payload; the rest stay 0
                assert_eq!(v[i][0], v.len().wrapping_add(i.wrapping_sub(s)) % v.len());
            }
        }
    }
}
1170+
11331171
#[test]
11341172
#[cfg(not(target_arch = "wasm32"))]
11351173
fn sort_unstable() {

0 commit comments

Comments
 (0)