Skip to content

Commit c55b9d3

Browse files
committed
Make u8x16 and u8x32 have Vector call ABI
Before this commit, u8x16 and u8x32 were repr(Rust) unions. This introduced unspecified behavior because the field offsets of repr(Rust) unions are not guaranteed to be at offset 0, so that field access was potentially UB. This commit fixes that, and closes rust-lang#588. The unions were also generating a lot of unnecessary memory operations. This commit fixes that as well. The issue is that unions have an Aggregate call ABI, which is the same as the call ABI of arrays. That is, they are passed around by memory, and not in Vector registers. This is good if most of the time one operates on them as arrays. This was, however, not the case: most of the operations on these unions use SIMD instructions. This means that the union needs to be copied into a SIMD register, operated on, and then spilled back to the stack, on every single operation. That's unnecessary, although apparently LLVM was able to optimize all the unnecessary memory operations away and leave these always in registers. This commit fixes this issue as well, by making u8x16 and u8x32 repr(transparent) newtypes over the architecture-specific vector types, giving them the Vector ABI. The vectors are then copied to the stack only when necessary, and as little as possible. This is done using mem::transmute, removing the need for unions altogether (fixing rust-lang#588 by not having to worry about union layout at all). To make it clear when the vectors are spilled onto the stack, the vector::replace(index, value) API has been removed; instead, only vector::bytes(self) and vector::replace_bytes(&mut self, [u8; N]) APIs are provided. This prevents spilling the vectors back and forth onto the stack every time an index needs to be modified: vector::bytes spills the vector to the stack once, all the random-access modifications are made in memory, and then vector::replace_bytes is used only once to move the memory back into a SIMD register.
1 parent 172898a commit c55b9d3

File tree

4 files changed

+82
-60
lines changed

4 files changed

+82
-60
lines changed

src/literal/teddy_avx2/imp.rs

+12-4
Original file line numberDiff line numberDiff line change
@@ -462,11 +462,19 @@ impl Mask {
462462
let byte_hi = (byte >> 4) as usize;
463463

464464
let lo = self.lo.extract(byte_lo) | ((1 << bucket) as u8);
465-
self.lo.replace(byte_lo, lo);
466-
self.lo.replace(byte_lo + 16, lo);
465+
{
466+
let mut lo_bytes = self.lo.bytes();
467+
lo_bytes[byte_lo] = lo;
468+
lo_bytes[byte_lo + 16] = lo;
469+
self.lo.replace_bytes(lo_bytes);
470+
}
467471

468472
let hi = self.hi.extract(byte_hi) | ((1 << bucket) as u8);
469-
self.hi.replace(byte_hi, hi);
470-
self.hi.replace(byte_hi + 16, hi);
473+
{
474+
let mut hi_bytes = self.hi.bytes();
475+
hi_bytes[byte_hi] = hi;
476+
hi_bytes[byte_hi + 16] = hi;
477+
self.hi.replace_bytes(hi_bytes);
478+
}
471479
}
472480
}

src/literal/teddy_ssse3/imp.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -772,9 +772,17 @@ impl Mask {
772772
let byte_hi = (byte >> 4) as usize;
773773

774774
let lo = self.lo.extract(byte_lo);
775-
self.lo.replace(byte_lo, ((1 << bucket) as u8) | lo);
775+
{
776+
let mut lo_bytes = self.lo.bytes();
777+
lo_bytes[byte_lo] = ((1 << bucket) as u8) | lo;
778+
self.lo.replace_bytes(lo_bytes);
779+
}
776780

777781
let hi = self.hi.extract(byte_hi);
778-
self.hi.replace(byte_hi, ((1 << bucket) as u8) | hi);
782+
{
783+
let mut hi_bytes = self.hi.bytes();
784+
hi_bytes[byte_hi] = ((1 << bucket) as u8) | hi;
785+
self.hi.replace_bytes(hi_bytes);
786+
}
779787
}
780788
}

src/vector/avx2.rs

+32-29
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
use std::arch::x86_64::*;
44
use std::fmt;
5+
use std::mem;
56

67
#[derive(Clone, Copy, Debug)]
78
pub struct AVX2VectorBuilder(());
@@ -56,15 +57,13 @@ impl AVX2VectorBuilder {
5657

5758
#[derive(Clone, Copy)]
5859
#[allow(non_camel_case_types)]
59-
pub union u8x32 {
60-
vector: __m256i,
61-
bytes: [u8; 32],
62-
}
60+
#[repr(transparent)]
61+
pub struct u8x32(__m256i);
6362

6463
impl u8x32 {
6564
#[inline]
6665
unsafe fn splat(n: u8) -> u8x32 {
67-
u8x32 { vector: _mm256_set1_epi8(n as i8) }
66+
u8x32(_mm256_set1_epi8(n as i8))
6867
}
6968

7069
#[inline]
@@ -76,7 +75,7 @@ impl u8x32 {
7675
#[inline]
7776
unsafe fn load_unchecked_unaligned(slice: &[u8]) -> u8x32 {
7877
let p = slice.as_ptr() as *const u8 as *const __m256i;
79-
u8x32 { vector: _mm256_loadu_si256(p) }
78+
u8x32(_mm256_loadu_si256(p))
8079
}
8180

8281
#[inline]
@@ -89,52 +88,45 @@ impl u8x32 {
8988
#[inline]
9089
unsafe fn load_unchecked(slice: &[u8]) -> u8x32 {
9190
let p = slice.as_ptr() as *const u8 as *const __m256i;
92-
u8x32 { vector: _mm256_load_si256(p) }
91+
u8x32(_mm256_load_si256(p))
9392
}
9493

9594
#[inline]
9695
pub fn extract(self, i: usize) -> u8 {
97-
// Safe because `bytes` is always accessible.
98-
unsafe { self.bytes[i] }
99-
}
100-
101-
#[inline]
102-
pub fn replace(&mut self, i: usize, byte: u8) {
103-
// Safe because `bytes` is always accessible.
104-
unsafe { self.bytes[i] = byte; }
96+
self.bytes()[i]
10597
}
10698

10799
#[inline]
108100
pub fn shuffle(self, indices: u8x32) -> u8x32 {
109101
// Safe because we know AVX2 is enabled.
110102
unsafe {
111-
u8x32 { vector: _mm256_shuffle_epi8(self.vector, indices.vector) }
103+
u8x32(_mm256_shuffle_epi8(self.0, indices.0))
112104
}
113105
}
114106

115107
#[inline]
116108
pub fn ne(self, other: u8x32) -> u8x32 {
117109
// Safe because we know AVX2 is enabled.
118110
unsafe {
119-
let boolv = _mm256_cmpeq_epi8(self.vector, other.vector);
111+
let boolv = _mm256_cmpeq_epi8(self.0, other.0);
120112
let ones = _mm256_set1_epi8(0xFF as u8 as i8);
121-
u8x32 { vector: _mm256_andnot_si256(boolv, ones) }
113+
u8x32(_mm256_andnot_si256(boolv, ones))
122114
}
123115
}
124116

125117
#[inline]
126118
pub fn and(self, other: u8x32) -> u8x32 {
127119
// Safe because we know AVX2 is enabled.
128120
unsafe {
129-
u8x32 { vector: _mm256_and_si256(self.vector, other.vector) }
121+
u8x32(_mm256_and_si256(self.0, other.0))
130122
}
131123
}
132124

133125
#[inline]
134126
pub fn movemask(self) -> u32 {
135127
// Safe because we know AVX2 is enabled.
136128
unsafe {
137-
_mm256_movemask_epi8(self.vector) as u32
129+
_mm256_movemask_epi8(self.0) as u32
138130
}
139131
}
140132

@@ -148,9 +140,9 @@ impl u8x32 {
148140
// TL;DR avx2's PALIGNR instruction is actually just two 128-bit
149141
// PALIGNR instructions, which is not what we want, so we need to
150142
// do some extra shuffling.
151-
let v = _mm256_permute2x128_si256(other.vector, self.vector, 0x21);
152-
let v = _mm256_alignr_epi8(self.vector, v, 14);
153-
u8x32 { vector: v }
143+
let v = _mm256_permute2x128_si256(other.0, self.0, 0x21);
144+
let v = _mm256_alignr_epi8(self.0, v, 14);
145+
u8x32(v)
154146
}
155147
}
156148

@@ -164,24 +156,35 @@ impl u8x32 {
164156
// TL;DR avx2's PALIGNR instruction is actually just two 128-bit
165157
// PALIGNR instructions, which is not what we want, so we need to
166158
// do some extra shuffling.
167-
let v = _mm256_permute2x128_si256(other.vector, self.vector, 0x21);
168-
let v = _mm256_alignr_epi8(self.vector, v, 15);
169-
u8x32 { vector: v }
159+
let v = _mm256_permute2x128_si256(other.0, self.0, 0x21);
160+
let v = _mm256_alignr_epi8(self.0, v, 15);
161+
u8x32(v)
170162
}
171163
}
172164

173165
#[inline]
174166
pub fn bit_shift_right_4(self) -> u8x32 {
175167
// Safe because we know AVX2 is enabled.
176168
unsafe {
177-
u8x32 { vector: _mm256_srli_epi16(self.vector, 4) }
169+
u8x32(_mm256_srli_epi16(self.0, 4))
178170
}
179171
}
172+
173+
#[inline]
174+
pub fn bytes(self) -> [u8; 32] {
175+
// Safe because __m256i and [u8; 32] are layout compatible
176+
unsafe { mem::transmute(self) }
177+
}
178+
179+
#[inline]
180+
pub fn replace_bytes(&mut self, value: [u8; 32]) {
181+
// Safe because __m256i and [u8; 32] are layout compatible
182+
self.0 = unsafe { mem::transmute(value) };
183+
}
180184
}
181185

182186
impl fmt::Debug for u8x32 {
183187
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
184-
// Safe because `bytes` is always accessible.
185-
unsafe { self.bytes.fmt(f) }
188+
self.bytes().fmt(f)
186189
}
187190
}

src/vector/ssse3.rs

+28-25
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
use std::arch::x86_64::*;
44
use std::fmt;
5+
use std::mem;
56

67
/// A builder for SSSE3 empowered vectors.
78
///
@@ -77,15 +78,13 @@ impl SSSE3VectorBuilder {
7778
/// inlined, otherwise you probably have a performance bug.
7879
#[derive(Clone, Copy)]
7980
#[allow(non_camel_case_types)]
80-
pub union u8x16 {
81-
vector: __m128i,
82-
bytes: [u8; 16],
83-
}
81+
#[repr(transparent)]
82+
pub struct u8x16(__m128i);
8483

8584
impl u8x16 {
8685
#[inline]
8786
unsafe fn splat(n: u8) -> u8x16 {
88-
u8x16 { vector: _mm_set1_epi8(n as i8) }
87+
u8x16(_mm_set1_epi8(n as i8))
8988
}
9089

9190
#[inline]
@@ -97,7 +96,7 @@ impl u8x16 {
9796
#[inline]
9897
unsafe fn load_unchecked_unaligned(slice: &[u8]) -> u8x16 {
9998
let v = _mm_loadu_si128(slice.as_ptr() as *const u8 as *const __m128i);
100-
u8x16 { vector: v }
99+
u8x16(v)
101100
}
102101

103102
#[inline]
@@ -110,83 +109,87 @@ impl u8x16 {
110109
#[inline]
111110
unsafe fn load_unchecked(slice: &[u8]) -> u8x16 {
112111
let v = _mm_load_si128(slice.as_ptr() as *const u8 as *const __m128i);
113-
u8x16 { vector: v }
112+
u8x16(v)
114113
}
115114

116115
#[inline]
117116
pub fn extract(self, i: usize) -> u8 {
118-
// Safe because `bytes` is always accessible.
119-
unsafe { self.bytes[i] }
120-
}
121-
122-
#[inline]
123-
pub fn replace(&mut self, i: usize, byte: u8) {
124-
// Safe because `bytes` is always accessible.
125-
unsafe { self.bytes[i] = byte; }
117+
self.bytes()[i]
126118
}
127119

128120
#[inline]
129121
pub fn shuffle(self, indices: u8x16) -> u8x16 {
130122
// Safe because we know SSSE3 is enabled.
131123
unsafe {
132-
u8x16 { vector: _mm_shuffle_epi8(self.vector, indices.vector) }
124+
u8x16(_mm_shuffle_epi8(self.0, indices.0))
133125
}
134126
}
135127

136128
#[inline]
137129
pub fn ne(self, other: u8x16) -> u8x16 {
138130
// Safe because we know SSSE3 is enabled.
139131
unsafe {
140-
let boolv = _mm_cmpeq_epi8(self.vector, other.vector);
132+
let boolv = _mm_cmpeq_epi8(self.0, other.0);
141133
let ones = _mm_set1_epi8(0xFF as u8 as i8);
142-
u8x16 { vector: _mm_andnot_si128(boolv, ones) }
134+
u8x16(_mm_andnot_si128(boolv, ones))
143135
}
144136
}
145137

146138
#[inline]
147139
pub fn and(self, other: u8x16) -> u8x16 {
148140
// Safe because we know SSSE3 is enabled.
149141
unsafe {
150-
u8x16 { vector: _mm_and_si128(self.vector, other.vector) }
142+
u8x16(_mm_and_si128(self.0, other.0))
151143
}
152144
}
153145

154146
#[inline]
155147
pub fn movemask(self) -> u32 {
156148
// Safe because we know SSSE3 is enabled.
157149
unsafe {
158-
_mm_movemask_epi8(self.vector) as u32
150+
_mm_movemask_epi8(self.0) as u32
159151
}
160152
}
161153

162154
#[inline]
163155
pub fn alignr_14(self, other: u8x16) -> u8x16 {
164156
// Safe because we know SSSE3 is enabled.
165157
unsafe {
166-
u8x16 { vector: _mm_alignr_epi8(self.vector, other.vector, 14) }
158+
u8x16(_mm_alignr_epi8(self.0, other.0, 14))
167159
}
168160
}
169161

170162
#[inline]
171163
pub fn alignr_15(self, other: u8x16) -> u8x16 {
172164
// Safe because we know SSSE3 is enabled.
173165
unsafe {
174-
u8x16 { vector: _mm_alignr_epi8(self.vector, other.vector, 15) }
166+
u8x16(_mm_alignr_epi8(self.0, other.0, 15))
175167
}
176168
}
177169

178170
#[inline]
179171
pub fn bit_shift_right_4(self) -> u8x16 {
180172
// Safe because we know SSSE3 is enabled.
181173
unsafe {
182-
u8x16 { vector: _mm_srli_epi16(self.vector, 4) }
174+
u8x16(_mm_srli_epi16(self.0, 4))
183175
}
184176
}
177+
178+
#[inline]
179+
pub fn bytes(self) -> [u8; 16] {
180+
// Safe because __m128i and [u8; 16] are layout compatible
181+
unsafe { mem::transmute(self) }
182+
}
183+
184+
#[inline]
185+
pub fn replace_bytes(&mut self, value: [u8; 16]) {
186+
// Safe because __m128i and [u8; 16] are layout compatible
187+
self.0 = unsafe { mem::transmute(value) };
188+
}
185189
}
186190

187191
impl fmt::Debug for u8x16 {
188192
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
189-
// Safe because `bytes` is always accessible.
190-
unsafe { self.bytes.fmt(f) }
193+
self.bytes().fmt(f)
191194
}
192195
}

0 commit comments

Comments
 (0)