Skip to content

Commit 8c86d97

Browse files
committed
Exact vector search (ydb-platform#2519)
1 parent ed0f418 commit 8c86d97

File tree

21 files changed

+1837
-0
lines changed

21 files changed

+1837
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#pragma once
2+
3+
#include <ydb/library/yql/public/udf/udf_helpers.h>
4+
5+
#include <util/generic/buffer.h>
6+
#include <util/stream/format.h>
7+
8+
using namespace NYql;
9+
using namespace NYql::NUdf;
10+
11+
template <typename TCallback>
12+
void EnumerateVector(const TUnboxedValuePod vector, TCallback&& callback) {
13+
const auto elements = vector.GetElements();
14+
if (elements) {
15+
const auto size = vector.GetListLength();
16+
17+
for (ui32 i = 0; i < size; ++i) {
18+
callback(elements[i].Get<float>());
19+
}
20+
} else {
21+
TUnboxedValue value;
22+
const auto it = vector.GetListIterator();
23+
while (it.Next(value)) {
24+
callback(value.Get<float>());
25+
}
26+
}
27+
}
28+
29+
template <typename TCallback>
30+
bool EnumerateVectors(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2, TCallback&& callback) {
31+
32+
auto enumerateBothSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValue* elements1, const TUnboxedValuePod vector2, const TUnboxedValue* elements2) {
33+
const auto size1 = vector1.GetListLength();
34+
const auto size2 = vector2.GetListLength();
35+
36+
// Length mismatch
37+
if (size1 != size2)
38+
return false;
39+
40+
for (ui32 i = 0; i < size1; ++i) {
41+
callback(elements1[i].Get<float>(), elements2[i].Get<float>());
42+
}
43+
44+
return true;
45+
};
46+
47+
auto enumerateOneSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValue* elements1, const TUnboxedValuePod vector2) {
48+
const auto size = vector1.GetListLength();
49+
ui32 idx = 0;
50+
TUnboxedValue value;
51+
const auto it = vector2.GetListIterator();
52+
53+
while (it.Next(value)) {
54+
callback(elements1[idx++].Get<float>(), value.Get<float>());
55+
}
56+
57+
// Length mismatch
58+
if (it.Next(value) || idx != size)
59+
return false;
60+
61+
return true;
62+
};
63+
64+
auto enumerateNoSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
65+
TUnboxedValue value1, value2;
66+
const auto it1 = vector1.GetListIterator();
67+
const auto it2 = vector2.GetListIterator();
68+
for (; it1.Next(value1) && it2.Next(value2);) {
69+
callback(value1.Get<float>(), value2.Get<float>());
70+
}
71+
72+
// Length mismatch
73+
if (it1.Next(value1) || it2.Next(value2))
74+
return false;
75+
76+
return true;
77+
};
78+
79+
const auto elements1 = vector1.GetElements();
80+
const auto elements2 = vector2.GetElements();
81+
if (elements1 && elements2) {
82+
if (!enumerateBothSized(vector1, elements1, vector2, elements2))
83+
return false;
84+
} else if (elements1) {
85+
if (!enumerateOneSized(vector1, elements1, vector2))
86+
return false;
87+
} else if (elements2) {
88+
if (!enumerateOneSized(vector2, elements2, vector1))
89+
return false;
90+
} else {
91+
if (!enumerateNoSized(vector1, vector2))
92+
return false;
93+
}
94+
95+
return true;
96+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#pragma once
2+
3+
#include "knn-enumerator.h"
4+
5+
#include <ydb/library/yql/public/udf/udf_helpers.h>
6+
7+
#include <util/generic/buffer.h>
8+
#include <util/stream/format.h>
9+
10+
using namespace NYql;
11+
using namespace NYql::NUdf;
12+
13+
// One-byte tag written as the first byte of a serialized vector; it selects the payload layout.
// Kept as an unscoped enum because the deserializer compares it directly with a raw ui8 byte.
enum EFormat : ui8 { FloatVector = 1 };
16+
17+
18+
class TFloatVectorSerializer {
19+
public:
20+
static TUnboxedValue Serialize(const IValueBuilder* valueBuilder, const TUnboxedValue x) {
21+
auto serialize = [&x] (IOutputStream& outStream) {
22+
const EFormat format = EFormat::FloatVector;
23+
outStream.Write(&format, 1);
24+
EnumerateVector(x, [&outStream] (float element) { outStream.Write(&element, sizeof(float)); });
25+
};
26+
27+
if (x.HasFastListLength()) {
28+
auto str = valueBuilder->NewStringNotFilled(sizeof(ui8) + x.GetListLength() * sizeof(float));
29+
auto strRef = str.AsStringRef();
30+
TMemoryOutput memoryOutput(strRef.Data(), strRef.Size());
31+
32+
serialize(memoryOutput);
33+
return str;
34+
} else {
35+
TString str;
36+
TStringOutput stringOutput(str);
37+
38+
serialize(stringOutput);
39+
return valueBuilder->NewString(str);
40+
}
41+
}
42+
43+
static TUnboxedValue Deserialize(const IValueBuilder *valueBuilder, const TStringRef& str) {
44+
//skip format byte, it was already read
45+
const char* buf = str.Data() + 1;
46+
const size_t len = str.Size() - 1;
47+
48+
if (len % sizeof(float) != 0)
49+
return {};
50+
51+
const ui32 count = len / sizeof(float);
52+
53+
TUnboxedValue* items = nullptr;
54+
auto res = valueBuilder->NewArray(count, items);
55+
56+
TMemoryInput inStr(buf, len);
57+
for (ui32 i = 0; i < count; ++i) {
58+
float element;
59+
if (inStr.Read(&element, sizeof(float)) != sizeof(float))
60+
return {};
61+
*items++ = TUnboxedValuePod{element};
62+
}
63+
64+
return res.Release();
65+
}
66+
};
67+
68+
69+
class TSerializerFacade {
70+
public:
71+
static TUnboxedValue Serialize(EFormat format, const IValueBuilder* valueBuilder, const TUnboxedValue x) {
72+
switch (format) {
73+
case EFormat::FloatVector:
74+
return TFloatVectorSerializer::Serialize(valueBuilder, x);
75+
default:
76+
return {};
77+
}
78+
}
79+
80+
static TUnboxedValue Deserialize(const IValueBuilder *valueBuilder, const TStringRef& str) {
81+
if (str.Size() == 0)
82+
return {};
83+
84+
ui8 formatByte = str.Data()[0];
85+
switch (formatByte) {
86+
case EFormat::FloatVector:
87+
return TFloatVectorSerializer::Deserialize(valueBuilder, str);
88+
default:
89+
return {};
90+
}
91+
}
92+
};
93+
+117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#include "knn-enumerator.h"
2+
#include "knn-serializer.h"
3+
4+
#include <ydb/library/yql/public/udf/udf_helpers.h>
5+
6+
#include <util/generic/buffer.h>
7+
#include <util/stream/format.h>
8+
9+
using namespace NYql;
10+
using namespace NYql::NUdf;
11+
12+
13+
// UDF: serializes a list of floats into its binary string representation.
SIMPLE_STRICT_UDF(TToBinaryString, char*(TAutoMap<TListType<float>>)) {
    // FloatVector is the only format for now; it will be taken from args in future.
    return TSerializerFacade::Serialize(EFormat::FloatVector, valueBuilder, args[0]);
}
19+
20+
// UDF: parses a binary string back into a list of floats.
// The result is empty when the facade cannot deserialize the string.
SIMPLE_STRICT_UDF(TFromBinaryString, TOptional<TListType<float>>(const char*)) {
    return TSerializerFacade::Deserialize(valueBuilder, args[0].AsStringRef());
}
25+
26+
27+
std::optional<float> InnerProductSimilarity(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
28+
float ret = 0;
29+
30+
if (!EnumerateVectors(vector1, vector2, [&ret](float el1, float el2) { ret += el1 * el2;}))
31+
return {};
32+
33+
return ret;
34+
}
35+
36+
std::optional<float> CosineSimilarity(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
37+
float len1 = 0;
38+
float len2 = 0;
39+
float innerProduct = 0;
40+
41+
if (!EnumerateVectors(vector1, vector2, [&](float el1, float el2) {
42+
innerProduct += el1 * el2;
43+
len1 += el1 * el1;
44+
len2 += el2 * el2;
45+
}))
46+
return {};
47+
48+
len1 = sqrt(len1);
49+
len2 = sqrt(len2);
50+
51+
float cosine = innerProduct / len1 / len2;
52+
53+
return cosine;
54+
}
55+
56+
std::optional<float> EuclideanDistance(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
57+
float ret = 0;
58+
59+
if (!EnumerateVectors(vector1, vector2, [&ret](float el1, float el2) { ret += (el1 - el2) * (el1 - el2);}))
60+
return {};
61+
62+
ret = sqrtf(ret);
63+
64+
return ret;
65+
}
66+
67+
// UDF: inner product of two float lists; empty when InnerProductSimilarity yields no value.
SIMPLE_STRICT_UDF(TInnerProductSimilarity, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
    Y_UNUSED(valueBuilder);

    if (const auto similarity = InnerProductSimilarity(args[0], args[1])) {
        return TUnboxedValuePod{*similarity};
    }

    return {};
}
76+
77+
// UDF: cosine similarity of two float lists; empty when CosineSimilarity yields no value.
SIMPLE_STRICT_UDF(TCosineSimilarity, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
    Y_UNUSED(valueBuilder);

    if (const auto similarity = CosineSimilarity(args[0], args[1])) {
        return TUnboxedValuePod{*similarity};
    }

    return {};
}
86+
87+
// UDF: cosine distance, computed as 1 minus the cosine similarity of two float lists;
// empty when CosineSimilarity yields no value.
SIMPLE_STRICT_UDF(TCosineDistance, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
    Y_UNUSED(valueBuilder);

    if (const auto similarity = CosineSimilarity(args[0], args[1])) {
        return TUnboxedValuePod{1 - *similarity};
    }

    return {};
}
96+
97+
// UDF: Euclidean distance between two float lists; empty when EuclideanDistance yields no value.
SIMPLE_STRICT_UDF(TEuclideanDistance, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
    Y_UNUSED(valueBuilder);

    if (const auto euclidean = EuclideanDistance(args[0], args[1])) {
        return TUnboxedValuePod{*euclidean};
    }

    return {};
}
106+
107+
// UDFs listed for TKnnModule; every SIMPLE_STRICT_UDF defined in this file appears here.
SIMPLE_MODULE(TKnnModule, TFromBinaryString, TToBinaryString, TInnerProductSimilarity, TCosineSimilarity, TCosineDistance, TEuclideanDistance)

REGISTER_MODULES(TKnnModule)
117+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
{
2+
"test.test[CosineDistance]": [
3+
{
4+
"uri": "file://test.test_CosineDistance_/results.txt"
5+
}
6+
],
7+
"test.test[CosineSimilarity]": [
8+
{
9+
"uri": "file://test.test_CosineSimilarity_/results.txt"
10+
}
11+
],
12+
"test.test[DeserializationError]": [
13+
{
14+
"uri": "file://test.test_DeserializationError_/results.txt"
15+
}
16+
],
17+
"test.test[EuclideanDistance]": [
18+
{
19+
"uri": "file://test.test_EuclideanDistance_/results.txt"
20+
}
21+
],
22+
"test.test[InnerProductSimilarity]": [
23+
{
24+
"uri": "file://test.test_InnerProductSimilarity_/results.txt"
25+
}
26+
],
27+
"test.test[LazyListSerialization]": [
28+
{
29+
"uri": "file://test.test_LazyListSerialization_/results.txt"
30+
}
31+
],
32+
"test.test[ListSerialization]": [
33+
{
34+
"uri": "file://test.test_ListSerialization_/results.txt"
35+
}
36+
]
37+
}

0 commit comments

Comments
 (0)