Skip to content

Commit 215d7b2

Browse files
authored
Merge 6aecd9a into 28bdd95
2 parents 28bdd95 + 6aecd9a commit 215d7b2

File tree

9 files changed

+549
-49
lines changed

9 files changed

+549
-49
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#pragma once
2+
3+
#include "util/system/types.h"
4+
5+
enum EFormat : ui8 {
6+
FloatVector = 1, // 4-byte per element
7+
ByteVector = 2, // 1-byte per element
8+
BitVector = 10 // 1-bit per element
9+
};
10+
11+
static constexpr size_t HeaderLen = sizeof(ui8);
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#pragma once
2+
3+
#include "knn-defines.h"
4+
#include "knn-serializer.h"
5+
6+
#include <ydb/library/yql/public/udf/udf_helpers.h>
7+
8+
#include <library/cpp/dot_product/dot_product.h>
9+
#include <util/generic/array_ref.h>
10+
#include <util/generic/buffer.h>
11+
#include <util/stream/format.h>
12+
13+
using namespace NYql;
14+
using namespace NYql::NUdf;
15+
16+
static ui16 KnnManhattanDistance(const TArrayRef<const ui64> vector1, const TArrayRef<const ui64> vector2) {
17+
Y_DEBUG_ABORT_UNLESS(vector1.size() == vector2.size());
18+
Y_DEBUG_ABORT_UNLESS(vector1.size() <= UINT16_MAX);
19+
20+
ui16 ret = 0;
21+
for (size_t i = 0; i < vector1.size(); ++i) {
22+
ret += __builtin_popcountll(vector1[i] ^ vector2[i]);
23+
}
24+
return ret;
25+
}
26+
27+
static std::optional<ui16> KnnManhattanDistance(const TStringRef& str1, const TStringRef& str2) {
28+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
29+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
30+
31+
if (Y_UNLIKELY(format1 != format2 || format1 != EFormat::BitVector))
32+
return {};
33+
34+
const TArrayRef<const ui64> vector1 = TKnnBitVectorSerializer::GetArray64(str1);
35+
const TArrayRef<const ui64> vector2 = TKnnBitVectorSerializer::GetArray64(str2);
36+
37+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector1.size() > UINT16_MAX))
38+
return {};
39+
40+
return KnnManhattanDistance(vector1, vector2);
41+
}
42+
43+
static std::optional<float> KnnDotProduct(const TStringRef& str1, const TStringRef& str2) {
44+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
45+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
46+
47+
if (Y_UNLIKELY(format1 != format2))
48+
return {};
49+
50+
switch (format1) {
51+
case EFormat::FloatVector: {
52+
const TArrayRef<const float> vector1 = TKnnSerializerFacade::GetArray<float>(str1);
53+
const TArrayRef<const float> vector2 = TKnnSerializerFacade::GetArray<float>(str2);
54+
55+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
56+
return {};
57+
58+
return ::DotProduct(vector1.data(), vector2.data(), vector1.size());
59+
}
60+
case EFormat::ByteVector: {
61+
const TArrayRef<const ui8> vector1 = TKnnSerializerFacade::GetArray<ui8>(str1);
62+
const TArrayRef<const ui8> vector2 = TKnnSerializerFacade::GetArray<ui8>(str2);
63+
64+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
65+
return {};
66+
67+
return ::DotProduct(vector1.data(), vector2.data(), vector1.size());
68+
}
69+
default:
70+
return {};
71+
}
72+
}
73+
74+
static std::optional<TTriWayDotProduct<float>> KnnTriWayDotProduct(const TStringRef& str1, const TStringRef& str2) {
75+
const ui8 format1 = str1.Data()[str1.Size() - HeaderLen];
76+
const ui8 format2 = str2.Data()[str2.Size() - HeaderLen];
77+
78+
if (Y_UNLIKELY(format1 != format2))
79+
return {};
80+
81+
switch (format1) {
82+
case EFormat::FloatVector: {
83+
const TArrayRef<const float> vector1 = TKnnSerializerFacade::GetArray<float>(str1);
84+
const TArrayRef<const float> vector2 = TKnnSerializerFacade::GetArray<float>(str2);
85+
86+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
87+
return {};
88+
89+
return ::TriWayDotProduct(vector1.data(), vector2.data(), vector1.size());
90+
}
91+
case EFormat::ByteVector: {
92+
const TArrayRef<const ui8> vector1 = TKnnSerializerFacade::GetArray<ui8>(str1);
93+
const TArrayRef<const ui8> vector2 = TKnnSerializerFacade::GetArray<ui8>(str2);
94+
95+
if (Y_UNLIKELY(vector1.size() != vector2.size() || vector1.empty() || vector2.empty()))
96+
return {};
97+
98+
TTriWayDotProduct<float> result;
99+
result.LL = ::DotProduct(vector1.data(), vector1.data(), vector1.size());
100+
result.LR = ::DotProduct(vector1.data(), vector2.data(), vector1.size());
101+
result.RR = ::DotProduct(vector2.data(), vector2.data(), vector1.size());
102+
return result;
103+
}
104+
default:
105+
return {};
106+
}
107+
}

ydb/library/yql/udfs/common/knn/knn-serializer.h

Lines changed: 103 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include "knn-defines.h"
34
#include "knn-enumerator.h"
45

56
#include <ydb/library/yql/public/udf/udf_helpers.h>
@@ -11,23 +12,21 @@
1112
using namespace NYql;
1213
using namespace NYql::NUdf;
1314

14-
enum EFormat : ui8 {
15-
FloatVector = 1
16-
};
17-
18-
static constexpr size_t HeaderLen = sizeof(ui8);
19-
20-
class TFloatVectorSerializer {
15+
template<typename T, EFormat Format>
16+
class TKnnVectorSerializer {
2117
public:
2218
static TUnboxedValue Serialize(const IValueBuilder* valueBuilder, const TUnboxedValue x) {
2319
auto serialize = [&x] (IOutputStream& outStream) {
24-
EnumerateVector(x, [&outStream] (float element) { outStream.Write(&element, sizeof(float)); });
25-
const EFormat format = EFormat::FloatVector;
20+
EnumerateVector(x, [&outStream] (float floatElement) {
21+
T element = static_cast<T>(floatElement);
22+
outStream.Write(&element, sizeof(T));
23+
});
24+
const EFormat format = Format;
2625
outStream.Write(&format, HeaderLen);
2726
};
2827

2928
if (x.HasFastListLength()) {
30-
auto str = valueBuilder->NewStringNotFilled(HeaderLen + x.GetListLength() * sizeof(float));
29+
auto str = valueBuilder->NewStringNotFilled(HeaderLen + x.GetListLength() * sizeof(T));
3130
auto strRef = str.AsStringRef();
3231
TMemoryOutput memoryOutput(strRef.Data(), strRef.Size());
3332

@@ -46,45 +45,115 @@ class TFloatVectorSerializer {
4645
const char* buf = str.Data();
4746
const size_t len = str.Size() - HeaderLen;
4847

49-
if (len % sizeof(float) != 0)
48+
if (Y_UNLIKELY(len % sizeof(T) != 0))
5049
return {};
5150

52-
const ui32 count = len / sizeof(float);
51+
const ui32 count = len / sizeof(T);
5352

5453
TUnboxedValue* items = nullptr;
5554
auto res = valueBuilder->NewArray(count, items);
5655

5756
TMemoryInput inStr(buf, len);
5857
for (ui32 i = 0; i < count; ++i) {
59-
float element;
60-
if (inStr.Read(&element, sizeof(float)) != sizeof(float))
58+
T element;
59+
if (Y_UNLIKELY(inStr.Read(&element, sizeof(T)) != sizeof(T)))
6160
return {};
62-
*items++ = TUnboxedValuePod{element};
61+
*items++ = TUnboxedValuePod{static_cast<float>(element)};
6362
}
6463

6564
return res.Release();
6665
}
6766

68-
static const TArrayRef<const float> GetArray(const TStringRef& str) {
67+
static const TArrayRef<const T> GetArray(const TStringRef& str) {
6968
const char* buf = str.Data();
7069
const size_t len = str.Size() - HeaderLen;
7170

72-
if (len % sizeof(float) != 0)
71+
if (Y_UNLIKELY(len % sizeof(T) != 0))
7372
return {};
7473

75-
const ui32 count = len / sizeof(float);
74+
const ui32 count = len / sizeof(T);
7675

77-
return MakeArrayRef(reinterpret_cast<const float*>(buf), count);
76+
return MakeArrayRef(reinterpret_cast<const T*>(buf), count);
7877
}
7978
};
8079

80+
// Encode all positive floats as bit 1, negative floats as bit 0.
81+
// So 1024 float vector is serialized in 1024/8=128 bytes.
82+
// Place all bits in ui64. So, only vector sizes divisible by 64 are supported.
83+
// Max vector lenght is 32767.
84+
class TKnnBitVectorSerializer {
85+
public:
86+
static TUnboxedValue Serialize(const IValueBuilder* valueBuilder, const TUnboxedValue x) {
87+
auto serialize = [&x] (IOutputStream& outStream) {
88+
ui64 accumulator = 0;
89+
ui8 filledBits = 0;
90+
ui64 lenght = 0;
91+
92+
EnumerateVector(x, [&] (float element) {
93+
if (element > 0)
94+
accumulator |= 1ll << filledBits;
95+
96+
++filledBits;
97+
if (filledBits == 64) {
98+
outStream.Write(&accumulator, sizeof(ui64));
99+
lenght++;
100+
accumulator = 0;
101+
filledBits = 0;
102+
}
103+
});
104+
105+
// only vector sizes divisible by 64 are supported
106+
if (Y_UNLIKELY(filledBits))
107+
return false;
108+
109+
// max vector lenght is 32767
110+
if (Y_UNLIKELY(lenght > UINT16_MAX))
111+
return false;
112+
113+
const EFormat format = EFormat::BitVector;
114+
outStream.Write(&format, HeaderLen);
115+
116+
return true;
117+
};
118+
119+
if (x.HasFastListLength()) {
120+
auto str = valueBuilder->NewStringNotFilled(HeaderLen + x.GetListLength() / 8);
121+
auto strRef = str.AsStringRef();
122+
TMemoryOutput memoryOutput(strRef.Data(), strRef.Size());
81123

82-
class TSerializerFacade {
124+
if (Y_UNLIKELY(!serialize(memoryOutput)))
125+
return {};
126+
127+
return str;
128+
} else {
129+
TString str;
130+
TStringOutput stringOutput(str);
131+
132+
if (Y_UNLIKELY(!serialize(stringOutput)))
133+
return {};
134+
135+
return valueBuilder->NewString(str);
136+
}
137+
}
138+
139+
static const TArrayRef<const ui64> GetArray64(const TStringRef& str) {
140+
const char* buf = str.Data();
141+
const size_t len = (str.Size() - HeaderLen) / sizeof(ui64);
142+
143+
return MakeArrayRef(reinterpret_cast<const ui64*>(buf), len);
144+
}
145+
};
146+
147+
class TKnnSerializerFacade {
83148
public:
84149
static TUnboxedValue Serialize(EFormat format, const IValueBuilder* valueBuilder, const TUnboxedValue x) {
85150
switch (format) {
86151
case EFormat::FloatVector:
87-
return TFloatVectorSerializer::Serialize(valueBuilder, x);
152+
return TKnnVectorSerializer<float, EFormat::FloatVector>::Serialize(valueBuilder, x);
153+
case EFormat::ByteVector:
154+
return TKnnVectorSerializer<ui8, EFormat::ByteVector>::Serialize(valueBuilder, x);
155+
case EFormat::BitVector:
156+
return TKnnBitVectorSerializer::Serialize(valueBuilder, x);
88157
default:
89158
return {};
90159
}
@@ -97,20 +166,29 @@ class TSerializerFacade {
97166
const ui8 format = str.Data()[str.Size() - HeaderLen];
98167
switch (format) {
99168
case EFormat::FloatVector:
100-
return TFloatVectorSerializer::Deserialize(valueBuilder, str);
169+
return TKnnVectorSerializer<float, EFormat::FloatVector>::Deserialize(valueBuilder, str);
170+
case EFormat::ByteVector:
171+
return TKnnVectorSerializer<ui8, EFormat::ByteVector>::Deserialize(valueBuilder, str);
172+
case EFormat::BitVector:
173+
return {};
101174
default:
102175
return {};
103176
}
104177
}
105178

106-
static const TArrayRef<const float> GetArray(const TStringRef& str) {
107-
if (str.Size() == 0)
179+
template<typename T>
180+
static const TArrayRef<const T> GetArray(const TStringRef& str) {
181+
if (Y_UNLIKELY(str.Size() == 0))
108182
return {};
109183

110184
const ui8 format = str.Data()[str.Size() - HeaderLen];
111185
switch (format) {
112186
case EFormat::FloatVector:
113-
return TFloatVectorSerializer::GetArray(str);
187+
return TKnnVectorSerializer<T, EFormat::FloatVector>::GetArray(str);
188+
case EFormat::ByteVector:
189+
return TKnnVectorSerializer<T, EFormat::ByteVector>::GetArray(str);
190+
case EFormat::BitVector:
191+
return {};
114192
default:
115193
return {};
116194
}

0 commit comments

Comments
 (0)