Skip to content

Commit a3fb357

Browse files
authored
Merge 10ba3ef into 3b87a8e
2 parents 3b87a8e + 10ba3ef commit a3fb357

File tree

5 files changed

+60
-20
lines changed

5 files changed

+60
-20
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#pragma once
2+
3+
#include <library/cpp/dot_product/dot_product.h>
4+
5+
// UDF strings are 8-byte aligned.
6+
// SSE prefers 16-byte alignment.
7+
// So, there are four cases:
8+
// 1. lhs % 16 == 0, rhs % 16 == 0
9+
// 2. lhs % 16 == 8, rhs % 16 == 0
10+
// 3. lhs % 16 == 0, rhs % 16 == 8
11+
// 4. lhs % 16 == 8, rhs % 16 == 8
12+
13+
// In case 4 we split the input into the leading unaligned elements (the first 8 bytes) and the 16-byte aligned remainder.
14+
// So, we make two DotProduct calls, which can improve performance.
15+
16+
// Case 1 is perfectly aligned for SSE. Read will be aligned.
17+
// Cases 2,3 are badly aligned. Read will be unaligned.
18+
19+
inline float KnnDotProduct(const float* lhs, const float* rhs, size_t length) {
20+
if ((size_t)lhs % 16 == 8 && (size_t)rhs % 16 == 8) {
21+
const size_t numUnaligned = 8 / sizeof(float);
22+
const float resUnaligned = DotProduct(lhs, rhs, numUnaligned);
23+
const float resAligned = DotProduct(lhs + numUnaligned, rhs + numUnaligned, length - numUnaligned);
24+
return resUnaligned + resAligned;
25+
} else {
26+
return DotProduct(lhs, rhs, length);
27+
}
28+
}
29+
30+
// Computes the three-way dot product (fields LL, LR and RR of the result) of
// two float arrays of `length` elements.
//
// When both pointers sit at an address with `% 16 == 8`, the first 8 bytes
// (two floats) are processed by a separate TriWayDotProduct call so that the
// remainder of both arrays starts on a 16-byte boundary; the partial results
// are then summed field by field.
//
// The split is taken only when the arrays hold more than `numUnaligned`
// elements: for shorter inputs the head call would read past the end of the
// buffers and `length - numUnaligned` would wrap around as an unsigned value.
static inline TTriWayDotProduct<float> KnnTriWayDotProduct(const float* lhs, const float* rhs, size_t length) {
    // Number of floats occupying the 8 bytes that precede the 16-byte boundary.
    constexpr size_t numUnaligned = 8 / sizeof(float);

    const bool bothHalfAligned =
        reinterpret_cast<size_t>(lhs) % 16 == 8 &&
        reinterpret_cast<size_t>(rhs) % 16 == 8;

    if (bothHalfAligned && length > numUnaligned) {
        const auto resUnaligned = TriWayDotProduct(lhs, rhs, numUnaligned);
        const auto resAligned = TriWayDotProduct(lhs + numUnaligned, rhs + numUnaligned, length - numUnaligned);
        return { resUnaligned.LL + resAligned.LL,
                 resUnaligned.LR + resAligned.LR,
                 resUnaligned.RR + resAligned.RR };
    }

    // Every other alignment combination, and inputs too short to split.
    return TriWayDotProduct(lhs, rhs, length);
}

ydb/library/yql/udfs/common/knn/knn-serializer.h

+13-14
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,23 @@
1111
using namespace NYql;
1212
using namespace NYql::NUdf;
1313

14-
enum EFormat : ui32 {
14+
enum EFormat : ui8 {
1515
FloatVector = 1
1616
};
1717

18+
static constexpr size_t HeaderLen = sizeof(ui8);
1819

1920
class TFloatVectorSerializer {
2021
public:
2122
static TUnboxedValue Serialize(const IValueBuilder* valueBuilder, const TUnboxedValue x) {
2223
auto serialize = [&x] (IOutputStream& outStream) {
23-
const EFormat format = EFormat::FloatVector;
24-
outStream.Write(&format, sizeof(ui32));
2524
EnumerateVector(x, [&outStream] (float element) { outStream.Write(&element, sizeof(float)); });
25+
const EFormat format = EFormat::FloatVector;
26+
outStream.Write(&format, HeaderLen);
2627
};
2728

2829
if (x.HasFastListLength()) {
29-
auto str = valueBuilder->NewStringNotFilled(sizeof(ui32) + x.GetListLength() * sizeof(float));
30+
auto str = valueBuilder->NewStringNotFilled(HeaderLen + x.GetListLength() * sizeof(float));
3031
auto strRef = str.AsStringRef();
3132
TMemoryOutput memoryOutput(strRef.Data(), strRef.Size());
3233

@@ -42,9 +43,8 @@ class TFloatVectorSerializer {
4243
}
4344

4445
static TUnboxedValue Deserialize(const IValueBuilder *valueBuilder, const TStringRef& str) {
45-
//skip format header, it was already read
46-
const char* buf = str.Data() + sizeof(ui32);
47-
const size_t len = str.Size() - sizeof(ui32);
46+
const char* buf = str.Data();
47+
const size_t len = str.Size() - HeaderLen;
4848

4949
if (len % sizeof(float) != 0)
5050
return {};
@@ -66,9 +66,8 @@ class TFloatVectorSerializer {
6666
}
6767

6868
static const TArrayRef<const float> GetArray(const TStringRef& str) {
69-
//skip format header, it was already read
70-
const char* buf = str.Data() + sizeof(ui32);
71-
const size_t len = str.Size() - sizeof(ui32);
69+
const char* buf = str.Data();
70+
const size_t len = str.Size() - HeaderLen;
7271

7372
if (len % sizeof(float) != 0)
7473
return {};
@@ -95,8 +94,8 @@ class TSerializerFacade {
9594
if (str.Size() == 0)
9695
return {};
9796

98-
const ui32* format = reinterpret_cast<const ui32*>(str.Data());
99-
switch (*format) {
97+
const ui8 format = str.Data()[str.Size() - HeaderLen];
98+
switch (format) {
10099
case EFormat::FloatVector:
101100
return TFloatVectorSerializer::Deserialize(valueBuilder, str);
102101
default:
@@ -108,8 +107,8 @@ class TSerializerFacade {
108107
if (str.Size() == 0)
109108
return {};
110109

111-
const ui32* format = reinterpret_cast<const ui32*>(str.Data());
112-
switch (*format) {
110+
const ui8 format = str.Data()[str.Size() - HeaderLen];
111+
switch (format) {
113112
case EFormat::FloatVector:
114113
return TFloatVectorSerializer::GetArray(str);
115114
default:

ydb/library/yql/udfs/common/knn/knn.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
#include "knn-distance.h"
12
#include "knn-enumerator.h"
23
#include "knn-serializer.h"
34

45
#include <ydb/library/yql/public/udf/udf_helpers.h>
56

6-
#include <library/cpp/dot_product/dot_product.h>
77
#include <util/generic/buffer.h>
88
#include <util/stream/format.h>
99

@@ -33,7 +33,7 @@ SIMPLE_STRICT_UDF(TInnerProductSimilarity, TOptional<float>(TAutoMap<const char*
3333
if (vector1.size() != vector2.size() || vector1.empty() || vector2.empty())
3434
return {};
3535

36-
const float dotProduct = DotProduct(vector1.data(), vector2.data(), vector1.size());
36+
const float dotProduct = KnnDotProduct(vector1.data(), vector2.data(), vector1.size());
3737
return TUnboxedValuePod{dotProduct};
3838
}
3939

@@ -46,7 +46,7 @@ SIMPLE_STRICT_UDF(TCosineSimilarity, TOptional<float>(TAutoMap<const char*>, TAu
4646
if (vector1.size() != vector2.size() || vector1.empty() || vector2.empty())
4747
return {};
4848

49-
const auto [ll, lr, rr] = TriWayDotProduct(vector1.data(), vector2.data(), vector1.size());
49+
const auto [ll, lr, rr] = KnnTriWayDotProduct(vector1.data(), vector2.data(), vector1.size());
5050
const float cosine = lr / std::sqrt(ll * rr);
5151
return TUnboxedValuePod{cosine};
5252
}
@@ -60,7 +60,7 @@ SIMPLE_STRICT_UDF(TCosineDistance, TOptional<float>(TAutoMap<const char*>, TAuto
6060
if (vector1.size() != vector2.size() || vector1.empty() || vector2.empty())
6161
return {};
6262

63-
const auto [ll, lr, rr] = TriWayDotProduct(vector1.data(), vector2.data(), vector1.size());
63+
const auto [ll, lr, rr] = KnnTriWayDotProduct(vector1.data(), vector2.data(), vector1.size());
6464
const float cosine = lr / std::sqrt(ll * rr);
6565
return TUnboxedValuePod{1 - cosine};
6666
}

ydb/library/yql/udfs/common/knn/test/canondata/test.test_LazyListSerialization_/results.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
"Data" = [
2121
[
2222
[
23-
"AQAAAAAAgD8AAABAAABAQAAAgEAAAKBA"
23+
"AACAPwAAAEAAAEBAAACAQAAAoEAB"
2424
]
2525
]
2626
]

ydb/library/yql/udfs/common/knn/test/canondata/test.test_ListSerialization_/results.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
"Data" = [
2121
[
2222
[
23-
"AQAAAJqZmT8zMxNAmplZQAAAkEAzM7NA"
23+
"mpmZPzMzE0CamVlAAACQQDMzs0AB"
2424
]
2525
]
2626
]

0 commit comments

Comments
 (0)