Skip to content

Commit 2ccc054

Browse files
authored
Rollup merge of rust-lang#59283 - SimonSapin:branchless-ascii-case, r=joshtriplett
Make ASCII case conversions more than 4× faster Reformatted output of `./x.py bench src/libcore --test-args ascii` below. The `libcore` benchmark calls `[u8]::make_ascii_lowercase`. `lookup` has code (effectively) identical to that before this PR, and ~~`branchless`~~ `mask_shifted_bool_match_range` after this PR. ~~See [code comments](rust-lang@ce933f7#diff-01076f91a26400b2db49663d787c2576R3796) in `u8::to_ascii_uppercase` in `src/libcore/num/mod.rs` for an explanation of the branchless algorithm.~~ **Update:** the algorithm was simplified while keeping the performance. See `branchless` v.s. `mask_shifted_bool_match_range` benchmarks. Credits to @raphlinus for the idea in https://twitter.com/raphlinus/status/1107654782544736261, which extends this algorithm to “fake SIMD” on `u32` to convert four bytes at a time. The `fake_simd_u32` benchmarks implements this with [`let (before, aligned, after) = bytes.align_to_mut::<u32>()`](https://doc.rust-lang.org/std/primitive.slice.html#method.align_to_mut). Note however that this is buggy when addition carries/overflows into the next byte (which does not happen if the input is known to be ASCII). This could be fixed (to optimize `[u8]::make_ascii_lowercase` and `[u8]::make_ascii_uppercase` in `src/libcore/slice/mod.rs`) either with some more bitwise trickery that I didn’t quite figure out, or by using “real” SIMD intrinsics for byte-wise addition. I did not pursue this however because the current (incorrect) fake SIMD algorithm is only marginally faster than the one-byte-at-a-time branchless algorithm. This is because LLVM auto-vectorizes the latter, as can be seen on https://rust.godbolt.org/z/anKtbR. Benchmark results on Linux x64 with Intel i7-7700K: (updated from rust-lang#59283 (comment)) ```rust 6830 bytes string: alloc_only ... bench: 112 ns/iter (+/- 0) = 62410 MB/s black_box_read_each_byte ... bench: 1,733 ns/iter (+/- 8) = 4033 MB/s lookup_table ... bench: 1,766 ns/iter (+/- 11) = 3958 MB/s branch_and_subtract ... bench: 417 ns/iter (+/- 1) = 16762 MB/s branch_and_mask ... bench: 401 ns/iter (+/- 1) = 17431 MB/s branchless ... bench: 365 ns/iter (+/- 0) = 19150 MB/s libcore ... bench: 367 ns/iter (+/- 1) = 19046 MB/s fake_simd_u32 ... bench: 361 ns/iter (+/- 2) = 19362 MB/s fake_simd_u64 ... bench: 361 ns/iter (+/- 1) = 19362 MB/s mask_mult_bool_branchy_lookup_table ... bench: 6,309 ns/iter (+/- 19) = 1107 MB/s mask_mult_bool_lookup_table ... bench: 4,183 ns/iter (+/- 29) = 1671 MB/s mask_mult_bool_match_range ... bench: 339 ns/iter (+/- 0) = 20619 MB/s mask_shifted_bool_match_range ... bench: 339 ns/iter (+/- 1) = 20619 MB/s 32 bytes string: alloc_only ... bench: 15 ns/iter (+/- 0) = 2133 MB/s black_box_read_each_byte ... bench: 29 ns/iter (+/- 0) = 1103 MB/s lookup_table ... bench: 24 ns/iter (+/- 4) = 1333 MB/s branch_and_subtract ... bench: 16 ns/iter (+/- 0) = 2000 MB/s branch_and_mask ... bench: 16 ns/iter (+/- 0) = 2000 MB/s branchless ... bench: 16 ns/iter (+/- 0) = 2000 MB/s libcore ... bench: 15 ns/iter (+/- 0) = 2133 MB/s fake_simd_u32 ... bench: 17 ns/iter (+/- 0) = 1882 MB/s fake_simd_u64 ... bench: 16 ns/iter (+/- 0) = 2000 MB/s mask_mult_bool_branchy_lookup_table ... bench: 42 ns/iter (+/- 0) = 761 MB/s mask_mult_bool_lookup_table ... bench: 35 ns/iter (+/- 0) = 914 MB/s mask_mult_bool_match_range ... bench: 16 ns/iter (+/- 0) = 2000 MB/s mask_shifted_bool_match_range ... bench: 16 ns/iter (+/- 0) = 2000 MB/s 7 bytes string: alloc_only ... bench: 14 ns/iter (+/- 0) = 500 MB/s black_box_read_each_byte ... bench: 22 ns/iter (+/- 0) = 318 MB/s lookup_table ... bench: 16 ns/iter (+/- 0) = 437 MB/s branch_and_subtract ... bench: 16 ns/iter (+/- 0) = 437 MB/s branch_and_mask ... bench: 16 ns/iter (+/- 0) = 437 MB/s branchless ... bench: 19 ns/iter (+/- 0) = 368 MB/s libcore ... bench: 20 ns/iter (+/- 0) = 350 MB/s fake_simd_u32 ... bench: 18 ns/iter (+/- 0) = 388 MB/s fake_simd_u64 ... bench: 21 ns/iter (+/- 0) = 333 MB/s mask_mult_bool_branchy_lookup_table ... bench: 20 ns/iter (+/- 0) = 350 MB/s mask_mult_bool_lookup_table ... bench: 19 ns/iter (+/- 0) = 368 MB/s mask_mult_bool_match_range ... bench: 19 ns/iter (+/- 0) = 368 MB/s mask_shifted_bool_match_range ... bench: 19 ns/iter (+/- 0) = 368 MB/s ```
2 parents dfce933 + 7fad370 commit 2ccc054

File tree

3 files changed

+374
-135
lines changed

3 files changed

+374
-135
lines changed

src/libcore/benches/ascii.rs

+349
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
// Lower-case ASCII 'a' is the first byte that has its highest bit set
2+
// after wrap-adding 0x1F:
3+
//
4+
// b'a' + 0x1F == 0x80 == 0b1000_0000
5+
// b'z' + 0x1F == 0x98 == 0b10011000
6+
//
7+
// Lower-case ASCII 'z' is the last byte that has its highest bit unset
8+
// after wrap-adding 0x05:
9+
//
10+
// b'a' + 0x05 == 0x66 == 0b0110_0110
11+
// b'z' + 0x05 == 0x7F == 0b0111_1111
12+
//
13+
// … except for 0xFB to 0xFF, but those are in the range of bytes
14+
// that have the highest bit unset again after adding 0x1F.
15+
//
16+
// So `(byte + 0x1f) & !(byte + 5)` has its highest bit set
17+
// iff `byte` is a lower-case ASCII letter.
18+
//
19+
// Lower-case ASCII letters all have the 0x20 bit set.
20+
// (Two positions right of 0x80, the highest bit.)
21+
// Unsetting that bit produces the same letter, in upper-case.
22+
//
23+
// Therefore:
24+
fn branchless_to_ascii_upper_case(byte: u8) -> u8 {
25+
byte &
26+
!(
27+
(
28+
byte.wrapping_add(0x1f) &
29+
!byte.wrapping_add(0x05) &
30+
0x80
31+
) >> 2
32+
)
33+
}
34+
35+
36+
macro_rules! benches {
37+
($( fn $name: ident($arg: ident: &mut [u8]) $body: block )+ @iter $( $is_: ident, )+) => {
38+
benches! {@
39+
$( fn $name($arg: &mut [u8]) $body )+
40+
$( fn $is_(bytes: &mut [u8]) { bytes.iter().all(u8::$is_) } )+
41+
}
42+
};
43+
44+
(@$( fn $name: ident($arg: ident: &mut [u8]) $body: block )+) => {
45+
benches!(mod short SHORT $($name $arg $body)+);
46+
benches!(mod medium MEDIUM $($name $arg $body)+);
47+
benches!(mod long LONG $($name $arg $body)+);
48+
};
49+
50+
(mod $mod_name: ident $input: ident $($name: ident $arg: ident $body: block)+) => {
51+
mod $mod_name {
52+
use super::*;
53+
54+
$(
55+
#[bench]
56+
fn $name(bencher: &mut Bencher) {
57+
bencher.bytes = $input.len() as u64;
58+
bencher.iter(|| {
59+
let mut vec = $input.as_bytes().to_vec();
60+
{
61+
let $arg = &mut vec[..];
62+
black_box($body);
63+
}
64+
vec
65+
})
66+
}
67+
)+
68+
}
69+
}
70+
}
71+
72+
use test::black_box;
73+
use test::Bencher;
74+
75+
benches! {
76+
fn case00_alloc_only(_bytes: &mut [u8]) {}
77+
78+
fn case01_black_box_read_each_byte(bytes: &mut [u8]) {
79+
for byte in bytes {
80+
black_box(*byte);
81+
}
82+
}
83+
84+
fn case02_lookup_table(bytes: &mut [u8]) {
85+
for byte in bytes {
86+
*byte = ASCII_UPPERCASE_MAP[*byte as usize]
87+
}
88+
}
89+
90+
fn case03_branch_and_subtract(bytes: &mut [u8]) {
91+
for byte in bytes {
92+
*byte = if b'a' <= *byte && *byte <= b'z' {
93+
*byte - b'a' + b'A'
94+
} else {
95+
*byte
96+
}
97+
}
98+
}
99+
100+
fn case04_branch_and_mask(bytes: &mut [u8]) {
101+
for byte in bytes {
102+
*byte = if b'a' <= *byte && *byte <= b'z' {
103+
*byte & !0x20
104+
} else {
105+
*byte
106+
}
107+
}
108+
}
109+
110+
fn case05_branchless(bytes: &mut [u8]) {
111+
for byte in bytes {
112+
*byte = branchless_to_ascii_upper_case(*byte)
113+
}
114+
}
115+
116+
fn case06_libcore(bytes: &mut [u8]) {
117+
bytes.make_ascii_uppercase()
118+
}
119+
120+
fn case07_fake_simd_u32(bytes: &mut [u8]) {
121+
let (before, aligned, after) = unsafe {
122+
bytes.align_to_mut::<u32>()
123+
};
124+
for byte in before {
125+
*byte = branchless_to_ascii_upper_case(*byte)
126+
}
127+
for word in aligned {
128+
// FIXME: this is incorrect for some byte values:
129+
// addition within a byte can carry/overflow into the next byte.
130+
// Test case: b"\xFFz "
131+
*word &= !(
132+
(
133+
word.wrapping_add(0x1f1f1f1f) &
134+
!word.wrapping_add(0x05050505) &
135+
0x80808080
136+
) >> 2
137+
)
138+
}
139+
for byte in after {
140+
*byte = branchless_to_ascii_upper_case(*byte)
141+
}
142+
}
143+
144+
fn case08_fake_simd_u64(bytes: &mut [u8]) {
145+
let (before, aligned, after) = unsafe {
146+
bytes.align_to_mut::<u64>()
147+
};
148+
for byte in before {
149+
*byte = branchless_to_ascii_upper_case(*byte)
150+
}
151+
for word in aligned {
152+
// FIXME: like above, this is incorrect for some byte values.
153+
*word &= !(
154+
(
155+
word.wrapping_add(0x1f1f1f1f_1f1f1f1f) &
156+
!word.wrapping_add(0x05050505_05050505) &
157+
0x80808080_80808080
158+
) >> 2
159+
)
160+
}
161+
for byte in after {
162+
*byte = branchless_to_ascii_upper_case(*byte)
163+
}
164+
}
165+
166+
fn case09_mask_mult_bool_branchy_lookup_table(bytes: &mut [u8]) {
167+
fn is_ascii_lowercase(b: u8) -> bool {
168+
if b >= 0x80 { return false }
169+
match ASCII_CHARACTER_CLASS[b as usize] {
170+
L | Lx => true,
171+
_ => false,
172+
}
173+
}
174+
for byte in bytes {
175+
*byte &= !(0x20 * (is_ascii_lowercase(*byte) as u8))
176+
}
177+
}
178+
179+
fn case10_mask_mult_bool_lookup_table(bytes: &mut [u8]) {
180+
fn is_ascii_lowercase(b: u8) -> bool {
181+
match ASCII_CHARACTER_CLASS[b as usize] {
182+
L | Lx => true,
183+
_ => false
184+
}
185+
}
186+
for byte in bytes {
187+
*byte &= !(0x20 * (is_ascii_lowercase(*byte) as u8))
188+
}
189+
}
190+
191+
fn case11_mask_mult_bool_match_range(bytes: &mut [u8]) {
192+
fn is_ascii_lowercase(b: u8) -> bool {
193+
match b {
194+
b'a'...b'z' => true,
195+
_ => false
196+
}
197+
}
198+
for byte in bytes {
199+
*byte &= !(0x20 * (is_ascii_lowercase(*byte) as u8))
200+
}
201+
}
202+
203+
fn case12_mask_shifted_bool_match_range(bytes: &mut [u8]) {
204+
fn is_ascii_lowercase(b: u8) -> bool {
205+
match b {
206+
b'a'...b'z' => true,
207+
_ => false
208+
}
209+
}
210+
for byte in bytes {
211+
*byte &= !((is_ascii_lowercase(*byte) as u8) << 5)
212+
}
213+
}
214+
215+
fn case13_subtract_shifted_bool_match_range(bytes: &mut [u8]) {
216+
fn is_ascii_lowercase(b: u8) -> bool {
217+
match b {
218+
b'a'...b'z' => true,
219+
_ => false
220+
}
221+
}
222+
for byte in bytes {
223+
*byte -= (is_ascii_lowercase(*byte) as u8) << 5
224+
}
225+
}
226+
227+
fn case14_subtract_multiplied_bool_match_range(bytes: &mut [u8]) {
228+
fn is_ascii_lowercase(b: u8) -> bool {
229+
match b {
230+
b'a'...b'z' => true,
231+
_ => false
232+
}
233+
}
234+
for byte in bytes {
235+
*byte -= (b'a' - b'A') * is_ascii_lowercase(*byte) as u8
236+
}
237+
}
238+
239+
@iter
240+
241+
is_ascii,
242+
is_ascii_alphabetic,
243+
is_ascii_uppercase,
244+
is_ascii_lowercase,
245+
is_ascii_alphanumeric,
246+
is_ascii_digit,
247+
is_ascii_hexdigit,
248+
is_ascii_punctuation,
249+
is_ascii_graphic,
250+
is_ascii_whitespace,
251+
is_ascii_control,
252+
}
253+
254+
macro_rules! repeat {
255+
($s: expr) => { concat!($s, $s, $s, $s, $s, $s, $s, $s, $s, $s) }
256+
}
257+
258+
const SHORT: &'static str = "Alice's";
259+
const MEDIUM: &'static str = "Alice's Adventures in Wonderland";
260+
const LONG: &'static str = repeat!(r#"
261+
La Guida di Bragia, a Ballad Opera for the Marionette Theatre (around 1850)
262+
Alice's Adventures in Wonderland (1865)
263+
Phantasmagoria and Other Poems (1869)
264+
Through the Looking-Glass, and What Alice Found There
265+
(includes "Jabberwocky" and "The Walrus and the Carpenter") (1871)
266+
The Hunting of the Snark (1876)
267+
Rhyme? And Reason? (1883) – shares some contents with the 1869 collection,
268+
including the long poem "Phantasmagoria"
269+
A Tangled Tale (1885)
270+
Sylvie and Bruno (1889)
271+
Sylvie and Bruno Concluded (1893)
272+
Pillow Problems (1893)
273+
What the Tortoise Said to Achilles (1895)
274+
Three Sunsets and Other Poems (1898)
275+
The Manlet (1903)[106]
276+
"#);
277+
278+
const ASCII_UPPERCASE_MAP: [u8; 256] = [
279+
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
280+
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
281+
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
282+
0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
283+
b' ', b'!', b'"', b'#', b'$', b'%', b'&', b'\'',
284+
b'(', b')', b'*', b'+', b',', b'-', b'.', b'/',
285+
b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7',
286+
b'8', b'9', b':', b';', b'<', b'=', b'>', b'?',
287+
b'@', b'A', b'B', b'C', b'D', b'E', b'F', b'G',
288+
b'H', b'I', b'J', b'K', b'L', b'M', b'N', b'O',
289+
b'P', b'Q', b'R', b'S', b'T', b'U', b'V', b'W',
290+
b'X', b'Y', b'Z', b'[', b'\\', b']', b'^', b'_',
291+
b'`',
292+
293+
b'A', b'B', b'C', b'D', b'E', b'F', b'G',
294+
b'H', b'I', b'J', b'K', b'L', b'M', b'N', b'O',
295+
b'P', b'Q', b'R', b'S', b'T', b'U', b'V', b'W',
296+
b'X', b'Y', b'Z',
297+
298+
b'{', b'|', b'}', b'~', 0x7f,
299+
0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87,
300+
0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f,
301+
0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97,
302+
0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f,
303+
0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7,
304+
0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf,
305+
0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7,
306+
0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf,
307+
0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7,
308+
0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf,
309+
0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7,
310+
0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
311+
0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7,
312+
0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
313+
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
314+
0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff,
315+
];
316+
317+
enum AsciiCharacterClass {
318+
C, // control
319+
Cw, // control whitespace
320+
W, // whitespace
321+
D, // digit
322+
L, // lowercase
323+
Lx, // lowercase hex digit
324+
U, // uppercase
325+
Ux, // uppercase hex digit
326+
P, // punctuation
327+
N, // Non-ASCII
328+
}
329+
use self::AsciiCharacterClass::*;
330+
331+
static ASCII_CHARACTER_CLASS: [AsciiCharacterClass; 256] = [
332+
// _0 _1 _2 _3 _4 _5 _6 _7 _8 _9 _a _b _c _d _e _f
333+
C, C, C, C, C, C, C, C, C, Cw,Cw,C, Cw,Cw,C, C, // 0_
334+
C, C, C, C, C, C, C, C, C, C, C, C, C, C, C, C, // 1_
335+
W, P, P, P, P, P, P, P, P, P, P, P, P, P, P, P, // 2_
336+
D, D, D, D, D, D, D, D, D, D, P, P, P, P, P, P, // 3_
337+
P, Ux,Ux,Ux,Ux,Ux,Ux,U, U, U, U, U, U, U, U, U, // 4_
338+
U, U, U, U, U, U, U, U, U, U, U, P, P, P, P, P, // 5_
339+
P, Lx,Lx,Lx,Lx,Lx,Lx,L, L, L, L, L, L, L, L, L, // 6_
340+
L, L, L, L, L, L, L, L, L, L, L, P, P, P, P, C, // 7_
341+
N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,
342+
N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,
343+
N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,
344+
N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,
345+
N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,
346+
N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,
347+
N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,
348+
N, N, N, N, N, N, N, N, N, N, N, N, N, N, N, N,
349+
];

src/libcore/benches/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ extern crate core;
55
extern crate test;
66

77
mod any;
8+
mod ascii;
89
mod char;
910
mod hash;
1011
mod iter;

0 commit comments

Comments
 (0)