Commit 6c285d6

jan-wassenberg authored and copybara-github committed
Add StoreInterleaved2. Refs #641
PiperOrigin-RevId: 444281375
1 parent 9e436d9 · commit 6c285d6

12 files changed: +288 −5 lines

g3doc/quick_reference.md (+7 −3)

@@ -790,11 +790,15 @@ F(src[tbl[i]])` because `Scatter` is more expensive than `Gather`.
     not fault, unlike `BlendedStore`. No alignment requirement. Potentially
     non-atomic, like `BlendedStore`.
 
+*   `D`: `u8` \
+    <code>void **StoreInterleaved2**(Vec&lt;D&gt; v0, Vec&lt;D&gt; v1, D, T*
+    p)</code>: equivalent to shuffling `v0, v1` followed by two `StoreU()`, such
+    that `p[0] == v0[0], p[1] == v1[0]`.
+
 *   `D`: `u8` \
     <code>void **StoreInterleaved3**(Vec&lt;D&gt; v0, Vec&lt;D&gt; v1,
-    Vec&lt;D&gt; v2, D, T* p)</code>: equivalent to shuffling `v0, v1, v2`
-    followed by three `StoreU()`, such that `p[0] == v0[0], p[1] == v1[0],
-    p[2] == v1[0]`. Useful for RGB samples.
+    Vec&lt;D&gt; v2, D, T* p)</code>: as above, but for three vectors (e.g. RGB
+    samples).
 
 *   `D`: `u8` \
     <code>void **StoreInterleaved4**(Vec&lt;D&gt; v0, Vec&lt;D&gt; v1,
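
For context, a hypothetical usage sketch of the new op (not part of this commit; the function name is invented, and the usual HWY_NAMESPACE/target-dispatch boilerplate plus the tail loop for counts not divisible by N are omitted): interleaving two planar u8 channels, e.g. gray and alpha, into one packed stream.

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

void InterleavePlanes(const uint8_t* HWY_RESTRICT plane0,
                      const uint8_t* HWY_RESTRICT plane1,
                      uint8_t* HWY_RESTRICT packed, size_t count) {
  const hn::ScalableTag<uint8_t> d;
  const size_t N = hn::Lanes(d);
  for (size_t i = 0; i + N <= count; i += N) {
    const auto v0 = hn::LoadU(d, plane0 + i);
    const auto v1 = hn::LoadU(d, plane1 + i);
    // Writes 2*N bytes such that packed[2*i] == plane0[i] and
    // packed[2*i + 1] == plane1[i], per the documented contract.
    hn::StoreInterleaved2(v0, v1, d, packed + 2 * i);
  }
}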

hwy/ops/arm_neon-inl.h (+31)

@@ -5334,6 +5334,37 @@ HWY_API size_t CompressBitsStore(Vec128<T, N> v,
   return PopCount(mask_bits);
 }
 
+// ------------------------------ StoreInterleaved2
+
+// 128 bits
+HWY_API void StoreInterleaved2(const Vec128<uint8_t> v0,
+                               const Vec128<uint8_t> v1,
+                               Full128<uint8_t> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  const uint8x16x2_t pair = {{v0.raw, v1.raw}};
+  vst2q_u8(unaligned, pair);
+}
+
+// 64 bits
+HWY_API void StoreInterleaved2(const Vec64<uint8_t> v0, const Vec64<uint8_t> v1,
+                               Full64<uint8_t> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  const uint8x8x2_t pair = {{v0.raw, v1.raw}};
+  vst2_u8(unaligned, pair);
+}
+
+// <= 32 bits: avoid writing more than N bytes by copying to buffer
+template <size_t N, HWY_IF_LE32(uint8_t, N)>
+HWY_API void StoreInterleaved2(const Vec128<uint8_t, N> v0,
+                               const Vec128<uint8_t, N> v1,
+                               Simd<uint8_t, N, 0> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  alignas(16) uint8_t buf[16];
+  const uint8x8x2_t pair = {{v0.raw, v1.raw}};
+  vst2_u8(buf, pair);
+  CopyBytes<N * 2>(buf, unaligned);
+}
+
 // ------------------------------ StoreInterleaved3
 
 // 128 bits
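
The NEON implementation maps directly onto the hardware's interleaving store: vst2q_u8 writes two registers to memory in a0 b0 a1 b1 ... order, so no explicit shuffle is needed; the <= 32-bit overload stages through a stack buffer only because vst2_u8 always writes 16 bytes. A minimal standalone sketch of the intrinsic (function name invented; not from this commit):

#include <arm_neon.h>
#include <stdint.h>

void Vst2Demo(const uint8_t a[16], const uint8_t b[16], uint8_t out[32]) {
  const uint8x16x2_t pair = {{vld1q_u8(a), vld1q_u8(b)}};
  vst2q_u8(out, pair);  // out = a[0] b[0] a[1] b[1] .. a[15] b[15]
}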

hwy/ops/arm_sve-inl.h (+14)

@@ -1009,6 +1009,20 @@ HWY_SVE_FOREACH_UIF3264(HWY_SVE_GATHER_INDEX, GatherIndex, ld1_gather)
 #undef HWY_SVE_GATHER_OFFSET
 #undef HWY_SVE_GATHER_INDEX
 
+// ------------------------------ StoreInterleaved2
+
+#define HWY_SVE_STORE2(BASE, CHAR, BITS, HALF, NAME, OP)                 \
+  template <size_t N, int kPow2>                                         \
+  HWY_API void NAME(HWY_SVE_V(BASE, BITS) v0, HWY_SVE_V(BASE, BITS) v1,  \
+                    HWY_SVE_D(BASE, BITS, N, kPow2) d,                   \
+                    HWY_SVE_T(BASE, BITS) * HWY_RESTRICT unaligned) {    \
+    const sv##BASE##BITS##x2_t tuple = svcreate2##_##CHAR##BITS(v0, v1); \
+    sv##OP##_##CHAR##BITS(detail::MakeMask(d), unaligned, tuple);        \
+  }
+HWY_SVE_FOREACH_U08(HWY_SVE_STORE2, StoreInterleaved2, st2)
+
+#undef HWY_SVE_STORE2
+
 // ------------------------------ StoreInterleaved3
 
 #define HWY_SVE_STORE3(BASE, CHAR, BITS, HALF, NAME, OP) \
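
Unlike NEON, SVE's st2 takes a governing predicate, so partial vectors need no buffer-and-copy path: detail::MakeMask(d) restricts the store to the lanes in use. A standalone sketch of the underlying ACLE calls (function name invented; assumes SVE compiler support; not from this commit):

#include <arm_sve.h>
#include <stdint.h>

void Svst2Demo(svuint8_t a, svuint8_t b, uint8_t* out, uint64_t n) {
  const svbool_t active = svwhilelt_b8(uint64_t{0}, n);  // first n lanes on
  const svuint8x2_t pair = svcreate2_u8(a, b);
  svst2_u8(active, out, pair);  // out = a[0] b[0] a[1] b[1] ...
}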

hwy/ops/emu128-inl.h (+12 −1)

@@ -1201,7 +1201,18 @@ HWY_API void BlendedStore(const Vec128<T, N> v, Mask128<T, N> m,
   }
 }
 
-// ------------------------------ StoreInterleaved3
+// ------------------------------ StoreInterleaved2/3/4
+
+template <size_t N>
+HWY_API void StoreInterleaved2(const Vec128<uint8_t, N> v0,
+                               const Vec128<uint8_t, N> v1,
+                               Simd<uint8_t, N, 0> /* tag */,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  for (size_t i = 0; i < N; ++i) {
+    *unaligned++ = v0.raw[i];
+    *unaligned++ = v1.raw[i];
+  }
+}
 
 template <size_t N>
 HWY_API void StoreInterleaved3(const Vec128<uint8_t, N> v0,

hwy/ops/rvv-inl.h (+24)

@@ -1386,6 +1386,30 @@ HWY_API VFromD<D> GatherIndex(D d, const TFromD<D>* HWY_RESTRICT base,
   return GatherOffset(d, base, ShiftLeft<3>(index));
 }
 
+// ------------------------------ StoreInterleaved2
+
+#define HWY_RVV_STORE2(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \
+                       MLEN, NAME, OP)                                         \
+  template <size_t N>                                                          \
+  HWY_API void NAME(HWY_RVV_V(BASE, SEW, LMUL) v0,                             \
+                    HWY_RVV_V(BASE, SEW, LMUL) v1,                             \
+                    HWY_RVV_D(BASE, SEW, N, SHIFT) d,                          \
+                    HWY_RVV_T(BASE, SEW) * HWY_RESTRICT unaligned) {           \
+    return v##OP##e8_v_##CHAR##SEW##LMUL(unaligned, v0, v1, Lanes(d));         \
+  }
+// Segments are limited to 8 registers, so we can only go up to LMUL=2.
+HWY_RVV_STORE2(uint, u, 8, _, _, mf8, _, _, /*kShift=*/-3, 64,
+               StoreInterleaved2, sseg2)
+HWY_RVV_STORE2(uint, u, 8, _, _, mf4, _, _, /*kShift=*/-2, 32,
+               StoreInterleaved2, sseg2)
+HWY_RVV_STORE2(uint, u, 8, _, _, mf2, _, _, /*kShift=*/-1, 16,
+               StoreInterleaved2, sseg2)
+HWY_RVV_STORE2(uint, u, 8, _, _, m1, _, _, /*kShift=*/0, 8, StoreInterleaved2,
+               sseg2)
+HWY_RVV_STORE2(uint, u, 8, _, _, m2, _, _, /*kShift=*/1, 4, StoreInterleaved2,
+               sseg2)
+#undef HWY_RVV_STORE2
+
 // ------------------------------ StoreInterleaved3
 
 #define HWY_RVV_STORE3(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \
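
Because the macro layers are dense, here is roughly what the m1 instantiation above expands to (a sketch using the v0.10 RVV intrinsic naming in use at the time of this commit):

template <size_t N>
HWY_API void StoreInterleaved2(vuint8m1_t v0, vuint8m1_t v1,
                               Simd<uint8_t, N, 0> d,
                               uint8_t* HWY_RESTRICT unaligned) {
  // vsseg2e8: unit-stride segment store of two fields, 8-bit elements.
  return vsseg2e8_v_u8m1(unaligned, v0, v1, Lanes(d));
}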

hwy/ops/scalar-inl.h (+8 −1)

@@ -954,7 +954,14 @@ HWY_API void BlendedStore(const Vec1<T> v, Mask1<T> m, Sisd<T> d,
   StoreU(v, d, p);
 }
 
-// ------------------------------ StoreInterleaved3
+// ------------------------------ StoreInterleaved2/3/4
+
+HWY_API void StoreInterleaved2(const Vec1<uint8_t> v0, const Vec1<uint8_t> v1,
+                               Sisd<uint8_t> d,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  StoreU(v0, d, unaligned + 0);
+  StoreU(v1, d, unaligned + 1);
+}
 
 HWY_API void StoreInterleaved3(const Vec1<uint8_t> v0, const Vec1<uint8_t> v1,
                                const Vec1<uint8_t> v2, Sisd<uint8_t> d,

hwy/ops/wasm_128-inl.h (+47)

@@ -3767,6 +3767,53 @@ HWY_API size_t CompressBitsStore(Vec128<T, N> v,
   return PopCount(mask_bits);
 }
 
+// ------------------------------ StoreInterleaved2
+
+// 128 bits
+HWY_API void StoreInterleaved2(const Vec128<uint8_t> v0,
+                               const Vec128<uint8_t> v1, Full128<uint8_t> d8,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  const RepartitionToWide<decltype(d8)> d16;
+  // let a,b denote v0,v1.
+  const auto ba0 = ZipLower(d16, v0, v1);  // b7 a7 .. b0 a0
+  const auto ba8 = ZipUpper(d16, v0, v1);
+  StoreU(BitCast(d8, ba0), d8, unaligned + 0 * 16);
+  StoreU(BitCast(d8, ba8), d8, unaligned + 1 * 16);
+}
+
+// 64 bits
+HWY_API void StoreInterleaved2(const Vec64<uint8_t> in0,
+                               const Vec64<uint8_t> in1,
+                               Full64<uint8_t> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  // Use full vectors to reduce the number of stores.
+  const Full128<uint8_t> d_full8;
+  const RepartitionToWide<decltype(d_full8)> d16;
+  const Vec128<uint8_t> v0{in0.raw};
+  const Vec128<uint8_t> v1{in1.raw};
+  // let a,b denote v0,v1.
+  const auto ba0 = ZipLower(d16, v0, v1);  // b7 a7 .. b0 a0
+  StoreU(BitCast(d_full8, ba0), d_full8, unaligned + 0 * 16);
+}
+
+// <= 32 bits
+template <size_t N, HWY_IF_LE32(uint8_t, N)>
+HWY_API void StoreInterleaved2(const Vec128<uint8_t, N> in0,
+                               const Vec128<uint8_t, N> in1,
+                               Simd<uint8_t, N, 0> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  // Use full vectors to reduce the number of stores.
+  const Full128<uint8_t> d_full8;
+  const RepartitionToWide<decltype(d_full8)> d16;
+  const Vec128<uint8_t> v0{in0.raw};
+  const Vec128<uint8_t> v1{in1.raw};
+  // let a,b denote v0,v1.
+  const auto ba0 = ZipLower(d16, v0, v1);  // b3 a3 .. b0 a0
+  alignas(16) uint8_t buf[16];
+  StoreU(BitCast(d_full8, ba0), d_full8, buf);
+  CopyBytes<2 * N>(buf, unaligned);
+}
+
 // ------------------------------ StoreInterleaved3 (CombineShiftRightBytes,
 // TableLookupBytes)

hwy/ops/wasm_256-inl.h (+8)

@@ -2808,6 +2808,14 @@ HWY_API size_t CompressBitsStore(Vec256<T> v, const uint8_t* HWY_RESTRICT bits,
   return PopCount(mask_bits);
 }
 
+// ------------------------------ StoreInterleaved2
+
+HWY_API void StoreInterleaved2(const Vec256<uint8_t> a, const Vec256<uint8_t> b,
+                               Full256<uint8_t> d,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  HWY_ASSERT(0);
+}
+
 // ------------------------------ StoreInterleaved3 (CombineShiftRightBytes,
 // TableLookupBytes)

hwy/ops/x86_128-inl.h (+47)

@@ -6351,6 +6351,53 @@ HWY_API size_t CompressBitsStore(Vec128<T, N> v,
 
 #endif  // HWY_TARGET <= HWY_AVX3
 
+// ------------------------------ StoreInterleaved2
+
+// 128 bits
+HWY_API void StoreInterleaved2(const Vec128<uint8_t> v0,
+                               const Vec128<uint8_t> v1, Full128<uint8_t> d8,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  const RepartitionToWide<decltype(d8)> d16;
+  // let a,b denote v0,v1.
+  const auto ba0 = ZipLower(d16, v0, v1);  // b7 a7 .. b0 a0
+  const auto ba8 = ZipUpper(d16, v0, v1);
+  StoreU(BitCast(d8, ba0), d8, unaligned + 0 * 16);
+  StoreU(BitCast(d8, ba8), d8, unaligned + 1 * 16);
+}
+
+// 64 bits
+HWY_API void StoreInterleaved2(const Vec64<uint8_t> in0,
+                               const Vec64<uint8_t> in1,
+                               Full64<uint8_t> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  // Use full vectors to reduce the number of stores.
+  const Full128<uint8_t> d_full8;
+  const RepartitionToWide<decltype(d_full8)> d16;
+  const Vec128<uint8_t> v0{in0.raw};
+  const Vec128<uint8_t> v1{in1.raw};
+  // let a,b denote v0,v1.
+  const auto ba0 = ZipLower(d16, v0, v1);  // b7 a7 .. b0 a0
+  StoreU(BitCast(d_full8, ba0), d_full8, unaligned + 0 * 16);
+}
+
+// <= 32 bits
+template <size_t N, HWY_IF_LE32(uint8_t, N)>
+HWY_API void StoreInterleaved2(const Vec128<uint8_t, N> in0,
+                               const Vec128<uint8_t, N> in1,
+                               Simd<uint8_t, N, 0> /*tag*/,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  // Use full vectors to reduce the number of stores.
+  const Full128<uint8_t> d_full8;
+  const RepartitionToWide<decltype(d_full8)> d16;
+  const Vec128<uint8_t> v0{in0.raw};
+  const Vec128<uint8_t> v1{in1.raw};
+  // let a,b denote v0,v1.
+  const auto ba0 = ZipLower(d16, v0, v1);  // b3 a3 .. b0 a0
+  alignas(16) uint8_t buf[16];
+  StoreU(BitCast(d_full8, ba0), d_full8, buf);
+  CopyBytes<2 * N>(buf, unaligned);
+}
+
 // ------------------------------ StoreInterleaved3 (CombineShiftRightBytes,
 // TableLookupBytes)

hwy/ops/x86_256-inl.h (+17)

@@ -4786,6 +4786,23 @@ HWY_API size_t CompressBitsStore(Vec256<T> v, const uint8_t* HWY_RESTRICT bits,
 
 #endif  // HWY_TARGET <= HWY_AVX3
 
+// ------------------------------ StoreInterleaved2
+
+HWY_API void StoreInterleaved2(const Vec256<uint8_t> v0,
+                               const Vec256<uint8_t> v1, Full256<uint8_t> d8,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  const RepartitionToWide<decltype(d8)> d16;
+  // let a,b denote v0,v1.
+  const auto ba0 = ZipLower(d16, v0, v1);  // b7 a7 .. b0 a0
+  const auto ba8 = ZipUpper(d16, v0, v1);
+  // Write lower halves, then upper. vperm2i128 is slow on Zen1 but we can
+  // efficiently combine two lower halves into 256 bits:
+  const auto out0 = BitCast(d8, ConcatLowerLower(d16, ba8, ba0));
+  const auto out1 = BitCast(d8, ConcatUpperUpper(d16, ba8, ba0));
+  StoreU(out0, d8, unaligned + 0 * 32);
+  StoreU(out1, d8, unaligned + 1 * 32);
+}
+
 // ------------------------------ StoreInterleaved3 (CombineShiftRightBytes,
 // TableLookupBytes, ConcatUpperLower)
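
Why the Concat fix-up is needed: AVX2's unpack instructions operate within 128-bit blocks, so storing ba0 then ba8 would emit blocks in the order ba0.lo, ba0.hi, ba8.lo, ba8.hi, whereas memory order requires ba0.lo, ba8.lo, ba0.hi, ba8.hi. A block-level model (a sketch; each Block stands for 128 bits, names invented):

#include <stdint.h>

struct Block { uint8_t bytes[16]; };

// ba0/ba8 are the ZipLower/ZipUpper results, each two 128-bit blocks.
void Avx2BlockOrder(const Block ba0[2], const Block ba8[2], Block out[4]) {
  out[0] = ba0[0];  // ConcatLowerLower(d16, ba8, ba0): both lower blocks
  out[1] = ba8[0];
  out[2] = ba0[1];  // ConcatUpperUpper(d16, ba8, ba0): both upper blocks
  out[3] = ba8[1];
}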

hwy/ops/x86_512-inl.h (+20)

@@ -3678,6 +3678,26 @@ HWY_API size_t CompressBitsStore(Vec512<T> v, const uint8_t* HWY_RESTRICT bits,
   return CompressStore(v, LoadMaskBits(d, bits), d, unaligned);
 }
 
+// ------------------------------ StoreInterleaved2
+
+HWY_API void StoreInterleaved2(const Vec512<uint8_t> v0,
+                               const Vec512<uint8_t> v1, Full512<uint8_t> d8,
+                               uint8_t* HWY_RESTRICT unaligned) {
+  const RepartitionToWide<decltype(d8)> d16;
+  // let a,b denote v0,v1.
+  const auto i = ZipLower(d16, v0, v1);  // b7 a7 .. b0 a0 in lower 128 bits
+  const auto j = ZipUpper(d16, v0, v1);
+  // 2x4 transpose: interleave 128-bit blocks.
+  const __m512i j1_j0_i1_i0 = _mm512_shuffle_i64x2(i.raw, j.raw, _MM_PERM_BABA);
+  const __m512i j3_j2_i3_i2 = _mm512_shuffle_i64x2(i.raw, j.raw, _MM_PERM_DCDC);
+  const __m512i j1_i1_j0_i0 =
+      _mm512_shuffle_i64x2(j1_j0_i1_i0, j1_j0_i1_i0, _MM_PERM_DBCA);
+  const __m512i j3_i3_j2_i2 =
+      _mm512_shuffle_i64x2(j3_j2_i3_i2, j3_j2_i3_i2, _MM_PERM_DBCA);
+  StoreU(Vec512<uint8_t>{j1_i1_j0_i0}, d8, unaligned + 0 * 64);
+  StoreU(Vec512<uint8_t>{j3_i3_j2_i2}, d8, unaligned + 1 * 64);
+}
+
 // ------------------------------ StoreInterleaved3 (CombineShiftRightBytes,
 // TableLookupBytes)
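
The AVX-512 version is the same idea extended to four blocks per input: i holds the low-half interleavings of 128-bit blocks 0..3, j the high halves, and the two rounds of _mm512_shuffle_i64x2 arrange them into memory order. Extending the block-level model from the AVX2 sketch above (reusing Block = 128 bits; names invented):

// First round groups (i0 i1 j0 j1) and (i2 i3 j2 j3); the _MM_PERM_DBCA
// round then interleaves within each group.
void Avx512BlockOrder(const Block i[4], const Block j[4], Block out[8]) {
  out[0] = i[0]; out[1] = j[0]; out[2] = i[1]; out[3] = j[1];  // first StoreU
  out[4] = i[2]; out[5] = j[2]; out[6] = i[3]; out[7] = j[3];  // second StoreU
}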

hwy/tests/memory_test.cc (+53)

@@ -131,6 +131,58 @@ HWY_NOINLINE void TestAllSafeCopyN() {
   ForAllTypes(ForPartialVectors<TestSafeCopyN>());
 }
 
+struct TestStoreInterleaved2 {
+  template <class T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    const size_t N = Lanes(d);
+
+    RandomState rng;
+
+    // Data to be interleaved
+    auto bytes = AllocateAligned<uint8_t>(2 * N);
+    for (size_t i = 0; i < 2 * N; ++i) {
+      bytes[i] = static_cast<uint8_t>(Random32(&rng) & 0xFF);
+    }
+    const auto in0 = Load(d, &bytes[0 * N]);
+    const auto in1 = Load(d, &bytes[1 * N]);
+
+    // Interleave here, ensure vector results match scalar
+    auto expected = AllocateAligned<T>(3 * N);
+    auto actual_aligned = AllocateAligned<T>(3 * N + 1);
+    T* actual = actual_aligned.get() + 1;
+
+    for (size_t rep = 0; rep < 100; ++rep) {
+      for (size_t i = 0; i < N; ++i) {
+        expected[2 * i + 0] = bytes[0 * N + i];
+        expected[2 * i + 1] = bytes[1 * N + i];
+        // Ensure we do not write more than 2*N bytes
+        expected[2 * N + i] = actual[2 * N + i] = 0;
+      }
+      StoreInterleaved2(in0, in1, d, actual);
+      size_t pos = 0;
+      if (!BytesEqual(expected.get(), actual, 3 * N, &pos)) {
+        Print(d, "in0", in0, pos / 2);
+        Print(d, "in1", in1, pos / 2);
+        const size_t i = pos;
+        fprintf(stderr, "interleaved %d %d %d %d %d %d %d %d\n", actual[i],
+                actual[i + 1], actual[i + 2], actual[i + 3], actual[i + 4],
+                actual[i + 5], actual[i + 6], actual[i + 7]);
+        HWY_ASSERT(false);
+      }
+    }
+  }
+};
+
+HWY_NOINLINE void TestAllStoreInterleaved2() {
+#if HWY_TARGET == HWY_RVV
+  // Segments are limited to 8 registers, so we can only go up to LMUL=2.
+  const ForExtendableVectors<TestStoreInterleaved2, 2> test;
+#else
+  const ForPartialVectors<TestStoreInterleaved2> test;
+#endif
+  test(uint8_t());
+}
+
 struct TestStoreInterleaved3 {
   template <class T, class D>
   HWY_NOINLINE void operator()(T /*unused*/, D d) {

@@ -443,6 +495,7 @@ namespace hwy {
 HWY_BEFORE_TEST(HwyMemoryTest);
 HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllLoadStore);
 HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllSafeCopyN);
+HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllStoreInterleaved2);
 HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllStoreInterleaved3);
 HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllStoreInterleaved4);
 HWY_EXPORT_AND_TEST_P(HwyMemoryTest, TestAllLoadDup128);