Skip to content

Commit 3257672

Browse files
authored
Merge 950bff4 into 4d22497
2 parents 4d22497 + 950bff4 commit 3257672

File tree

60 files changed

+22772
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+22772
-0
lines changed

library/cpp/dot_product/README.md

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
Библиотека для вычисления скалярного произведения векторов.
2+
=====================================================
3+
4+
Данная библиотека содержит функцию DotProduct, вычисляющую скалярное произведение векторов различных типов.
5+
В отличии от наивной реализации, библиотека использует SSE и работает существенно быстрее. Для сравнения
6+
можно посмотреть результаты бенчмарка.
7+
8+
Типичное использование - замена кусков кода вроде:
9+
```
10+
for (int i = 0; i < len; i++)
11+
dot_product += a[i] * b[i]);
12+
```
13+
на существенно более эффективный вызов ```DotProduct(a, b, len)```.
14+
15+
Работает для типов i8, i32, float, double.
+274
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
#include "dot_product.h"
2+
#include "dot_product_sse.h"
3+
#include "dot_product_avx2.h"
4+
#include "dot_product_simple.h"
5+
6+
#include <library/cpp/sse/sse.h>
7+
#include <library/cpp/testing/common/env.h>
8+
#include <util/system/compiler.h>
9+
#include <util/generic/utility.h>
10+
#include <util/system/cpu_id.h>
11+
#include <util/system/env.h>
12+
13+
namespace NDotProductImpl {
14+
i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept = &DotProductSimple;
15+
ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept = &DotProductSimple;
16+
i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept = &DotProductSimple;
17+
float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept = &DotProductSimple;
18+
double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept = &DotProductSimple;
19+
20+
namespace {
21+
[[maybe_unused]] const int _ = [] {
22+
if (!FromYaTest() && GetEnv("Y_NO_AVX_IN_DOT_PRODUCT") == "" && NX86::HaveAVX2() && NX86::HaveFMA()) {
23+
DotProductI8Impl = &DotProductAvx2;
24+
DotProductUi8Impl = &DotProductAvx2;
25+
DotProductI32Impl = &DotProductAvx2;
26+
DotProductFloatImpl = &DotProductAvx2;
27+
DotProductDoubleImpl = &DotProductAvx2;
28+
} else {
29+
#ifdef ARCADIA_SSE
30+
DotProductI8Impl = &DotProductSse;
31+
DotProductUi8Impl = &DotProductSse;
32+
DotProductI32Impl = &DotProductSse;
33+
DotProductFloatImpl = &DotProductSse;
34+
DotProductDoubleImpl = &DotProductSse;
35+
#endif
36+
}
37+
return 0;
38+
}();
39+
}
40+
}
41+
42+
#ifdef ARCADIA_SSE
43+
float L2NormSquared(const float* v, size_t length) noexcept {
44+
__m128 sum1 = _mm_setzero_ps();
45+
__m128 sum2 = _mm_setzero_ps();
46+
__m128 a1, a2, m1, m2;
47+
48+
while (length >= 8) {
49+
a1 = _mm_loadu_ps(v);
50+
m1 = _mm_mul_ps(a1, a1);
51+
52+
a2 = _mm_loadu_ps(v + 4);
53+
sum1 = _mm_add_ps(sum1, m1);
54+
55+
m2 = _mm_mul_ps(a2, a2);
56+
sum2 = _mm_add_ps(sum2, m2);
57+
58+
length -= 8;
59+
v += 8;
60+
}
61+
62+
if (length >= 4) {
63+
a1 = _mm_loadu_ps(v);
64+
sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, a1));
65+
66+
length -= 4;
67+
v += 4;
68+
}
69+
70+
sum1 = _mm_add_ps(sum1, sum2);
71+
72+
if (length) {
73+
switch (length) {
74+
case 3:
75+
a1 = _mm_set_ps(0.0f, v[2], v[1], v[0]);
76+
break;
77+
78+
case 2:
79+
a1 = _mm_set_ps(0.0f, 0.0f, v[1], v[0]);
80+
break;
81+
82+
case 1:
83+
a1 = _mm_set_ps(0.0f, 0.0f, 0.0f, v[0]);
84+
break;
85+
86+
default:
87+
Y_UNREACHABLE();
88+
}
89+
90+
sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, a1));
91+
}
92+
93+
alignas(16) float res[4];
94+
_mm_store_ps(res, sum1);
95+
96+
return res[0] + res[1] + res[2] + res[3];
97+
}
98+
99+
template <bool computeLL, bool computeLR, bool computeRR>
100+
Y_FORCE_INLINE
101+
static void TriWayDotProductIteration(__m128& sumLL, __m128& sumLR, __m128& sumRR, const __m128 a, const __m128 b) {
102+
if constexpr (computeLL) {
103+
sumLL = _mm_add_ps(sumLL, _mm_mul_ps(a, a));
104+
}
105+
if constexpr (computeLR) {
106+
sumLR = _mm_add_ps(sumLR, _mm_mul_ps(a, b));
107+
}
108+
if constexpr (computeRR) {
109+
sumRR = _mm_add_ps(sumRR, _mm_mul_ps(b, b));
110+
}
111+
}
112+
113+
114+
template <bool computeLL, bool computeLR, bool computeRR>
115+
static TTriWayDotProduct<float> TriWayDotProductImpl(const float* lhs, const float* rhs, size_t length) noexcept {
116+
__m128 sumLL1 = _mm_setzero_ps();
117+
__m128 sumLR1 = _mm_setzero_ps();
118+
__m128 sumRR1 = _mm_setzero_ps();
119+
__m128 sumLL2 = _mm_setzero_ps();
120+
__m128 sumLR2 = _mm_setzero_ps();
121+
__m128 sumRR2 = _mm_setzero_ps();
122+
123+
while (length >= 8) {
124+
TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0));
125+
TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL2, sumLR2, sumRR2, _mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4));
126+
length -= 8;
127+
lhs += 8;
128+
rhs += 8;
129+
}
130+
131+
if (length >= 4) {
132+
TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0));
133+
length -= 4;
134+
lhs += 4;
135+
rhs += 4;
136+
}
137+
138+
if constexpr (computeLL) {
139+
sumLL1 = _mm_add_ps(sumLL1, sumLL2);
140+
}
141+
if constexpr (computeLR) {
142+
sumLR1 = _mm_add_ps(sumLR1, sumLR2);
143+
}
144+
if constexpr (computeRR) {
145+
sumRR1 = _mm_add_ps(sumRR1, sumRR2);
146+
}
147+
148+
if (length) {
149+
__m128 a, b;
150+
switch (length) {
151+
case 3:
152+
a = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]);
153+
b = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]);
154+
break;
155+
case 2:
156+
a = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]);
157+
b = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]);
158+
break;
159+
case 1:
160+
a = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]);
161+
b = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]);
162+
break;
163+
default:
164+
Y_UNREACHABLE();
165+
}
166+
TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, a, b);
167+
}
168+
169+
__m128 t0 = sumLL1;
170+
__m128 t1 = sumLR1;
171+
__m128 t2 = sumRR1;
172+
__m128 t3 = _mm_setzero_ps();
173+
_MM_TRANSPOSE4_PS(t0, t1, t2, t3);
174+
t0 = _mm_add_ps(t0, t1);
175+
t0 = _mm_add_ps(t0, t2);
176+
t0 = _mm_add_ps(t0, t3);
177+
178+
alignas(16) float res[4];
179+
_mm_store_ps(res, t0);
180+
TTriWayDotProduct<float> result{res[0], res[1], res[2]};
181+
static constexpr const TTriWayDotProduct<float> def;
182+
// fill skipped fields with default values
183+
if constexpr (!computeLL) {
184+
result.LL = def.LL;
185+
}
186+
if constexpr (!computeLR) {
187+
result.LR = def.LR;
188+
}
189+
if constexpr (!computeRR) {
190+
result.RR = def.RR;
191+
}
192+
return result;
193+
}
194+
195+
196+
TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept {
197+
mask &= 0b111;
198+
if (Y_LIKELY(mask == 0b111)) { // compute dot-product and length² of two vectors
199+
return TriWayDotProductImpl<true, true, true>(lhs, rhs, length);
200+
} else if (Y_LIKELY(mask == 0b110 || mask == 0b011)) { // compute dot-product and length² of one vector
201+
const bool computeLL = (mask == 0b110);
202+
if (!computeLL) {
203+
DoSwap(lhs, rhs);
204+
}
205+
auto result = TriWayDotProductImpl<true, true, false>(lhs, rhs, length);
206+
if (!computeLL) {
207+
DoSwap(result.LL, result.RR);
208+
}
209+
return result;
210+
} else {
211+
// dispatch unlikely & sparse cases
212+
TTriWayDotProduct<float> result{};
213+
switch(mask) {
214+
case 0b000:
215+
break;
216+
case 0b100:
217+
result.LL = L2NormSquared(lhs, length);
218+
break;
219+
case 0b010:
220+
result.LR = DotProduct(lhs, rhs, length);
221+
break;
222+
case 0b001:
223+
result.RR = L2NormSquared(rhs, length);
224+
break;
225+
case 0b101:
226+
result.LL = L2NormSquared(lhs, length);
227+
result.RR = L2NormSquared(rhs, length);
228+
break;
229+
default:
230+
Y_UNREACHABLE();
231+
}
232+
return result;
233+
}
234+
}
235+
236+
#else
237+
238+
float L2NormSquared(const float* v, size_t length) noexcept {
239+
return DotProduct(v, v, length);
240+
}
241+
242+
TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept {
243+
TTriWayDotProduct<float> result;
244+
if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LL)) {
245+
result.LL = L2NormSquared(lhs, length);
246+
}
247+
if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LR)) {
248+
result.LR = DotProduct(lhs, rhs, length);
249+
}
250+
if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::RR)) {
251+
result.RR = L2NormSquared(rhs, length);
252+
}
253+
return result;
254+
}
255+
256+
#endif // ARCADIA_SSE
257+
258+
namespace NDotProduct {
259+
void DisableAvx2() {
260+
#ifdef ARCADIA_SSE
261+
NDotProductImpl::DotProductI8Impl = &DotProductSse;
262+
NDotProductImpl::DotProductUi8Impl = &DotProductSse;
263+
NDotProductImpl::DotProductI32Impl = &DotProductSse;
264+
NDotProductImpl::DotProductFloatImpl = &DotProductSse;
265+
NDotProductImpl::DotProductDoubleImpl = &DotProductSse;
266+
#else
267+
NDotProductImpl::DotProductI8Impl = &DotProductSimple;
268+
NDotProductImpl::DotProductUi8Impl = &DotProductSimple;
269+
NDotProductImpl::DotProductI32Impl = &DotProductSimple;
270+
NDotProductImpl::DotProductFloatImpl = &DotProductSimple;
271+
NDotProductImpl::DotProductDoubleImpl = &DotProductSimple;
272+
#endif
273+
}
274+
}

library/cpp/dot_product/dot_product.h

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#pragma once
2+
3+
#include <util/system/types.h>
4+
#include <util/system/compiler.h>
5+
6+
#include <numeric>
7+
8+
/**
9+
* Dot product (Inner product or scalar product) implementation using SSE when possible.
10+
*/
11+
namespace NDotProductImpl {
12+
extern i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept;
13+
extern ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept;
14+
extern i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept;
15+
extern float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept;
16+
extern double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept;
17+
}
18+
19+
Y_PURE_FUNCTION
20+
inline i32 DotProduct(const i8* lhs, const i8* rhs, size_t length) noexcept {
21+
return NDotProductImpl::DotProductI8Impl(lhs, rhs, length);
22+
}
23+
24+
Y_PURE_FUNCTION
25+
inline ui32 DotProduct(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
26+
return NDotProductImpl::DotProductUi8Impl(lhs, rhs, length);
27+
}
28+
29+
Y_PURE_FUNCTION
30+
inline i64 DotProduct(const i32* lhs, const i32* rhs, size_t length) noexcept {
31+
return NDotProductImpl::DotProductI32Impl(lhs, rhs, length);
32+
}
33+
34+
Y_PURE_FUNCTION
35+
inline float DotProduct(const float* lhs, const float* rhs, size_t length) noexcept {
36+
return NDotProductImpl::DotProductFloatImpl(lhs, rhs, length);
37+
}
38+
39+
Y_PURE_FUNCTION
40+
inline double DotProduct(const double* lhs, const double* rhs, size_t length) noexcept {
41+
return NDotProductImpl::DotProductDoubleImpl(lhs, rhs, length);
42+
}
43+
44+
/**
45+
* Dot product to itself
46+
*/
47+
Y_PURE_FUNCTION
48+
float L2NormSquared(const float* v, size_t length) noexcept;
49+
50+
// TODO(yazevnul): make `L2NormSquared` for double, this should be faster than `DotProduct`
51+
// where `lhs == rhs` because it will save N load instructions.
52+
53+
template <typename T>
54+
struct TTriWayDotProduct {
55+
T LL = 1;
56+
T LR = 0;
57+
T RR = 1;
58+
};
59+
60+
enum class ETriWayDotProductComputeMask: unsigned {
61+
// basic
62+
LL = 0b100,
63+
LR = 0b010,
64+
RR = 0b001,
65+
66+
// useful combinations
67+
All = 0b111,
68+
Left = 0b110, // skip computation of R·R
69+
Right = 0b011, // skip computation of L·L
70+
};
71+
72+
Y_PURE_FUNCTION
73+
TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept;
74+
75+
/**
76+
* For two vectors L and R computes 3 dot-products: L·L, L·R, R·R
77+
*/
78+
Y_PURE_FUNCTION
79+
static inline TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, ETriWayDotProductComputeMask mask = ETriWayDotProductComputeMask::All) noexcept {
80+
return TriWayDotProduct(lhs, rhs, length, static_cast<unsigned>(mask));
81+
}
82+
83+
namespace NDotProduct {
84+
// Simpler wrapper allowing to use this functions as template argument.
85+
template <typename T>
86+
struct TDotProduct {
87+
using TResult = decltype(DotProduct(static_cast<const T*>(nullptr), static_cast<const T*>(nullptr), 0));
88+
Y_PURE_FUNCTION
89+
inline TResult operator()(const T* l, const T* r, size_t length) const {
90+
return DotProduct(l, r, length);
91+
}
92+
};
93+
94+
void DisableAvx2();
95+
}
96+

0 commit comments

Comments
 (0)