Skip to content

Commit 7e15388

Browse files
authored
Merge 91d8f3f into 3d4c044
2 parents 3d4c044 + 91d8f3f commit 7e15388

File tree

14 files changed

+90
-829
lines changed

14 files changed

+90
-829
lines changed

ydb/library/yql/udfs/common/knn/knn-serializer.h

+35-8
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44

55
#include <ydb/library/yql/public/udf/udf_helpers.h>
66

7+
#include <util/generic/array_ref.h>
78
#include <util/generic/buffer.h>
89
#include <util/stream/format.h>
910

1011
using namespace NYql;
1112
using namespace NYql::NUdf;
1213

13-
enum EFormat : ui8 {
14+
enum EFormat : ui32 {
1415
FloatVector = 1
1516
};
1617

@@ -20,12 +21,12 @@ class TFloatVectorSerializer {
2021
static TUnboxedValue Serialize(const IValueBuilder* valueBuilder, const TUnboxedValue x) {
2122
auto serialize = [&x] (IOutputStream& outStream) {
2223
const EFormat format = EFormat::FloatVector;
23-
outStream.Write(&format, 1);
24+
outStream.Write(&format, sizeof(ui32));
2425
EnumerateVector(x, [&outStream] (float element) { outStream.Write(&element, sizeof(float)); });
2526
};
2627

2728
if (x.HasFastListLength()) {
28-
auto str = valueBuilder->NewStringNotFilled(sizeof(ui8) + x.GetListLength() * sizeof(float));
29+
auto str = valueBuilder->NewStringNotFilled(sizeof(ui32) + x.GetListLength() * sizeof(float));
2930
auto strRef = str.AsStringRef();
3031
TMemoryOutput memoryOutput(strRef.Data(), strRef.Size());
3132

@@ -41,9 +42,9 @@ class TFloatVectorSerializer {
4142
}
4243

4344
static TUnboxedValue Deserialize(const IValueBuilder *valueBuilder, const TStringRef& str) {
44-
//skip format byte, it was already read
45-
const char* buf = str.Data() + 1;
46-
const size_t len = str.Size() - 1;
45+
//skip format header, it was already read
46+
const char* buf = str.Data() + sizeof(ui32);
47+
const size_t len = str.Size() - sizeof(ui32);
4748

4849
if (len % sizeof(float) != 0)
4950
return {};
@@ -63,6 +64,19 @@ class TFloatVectorSerializer {
6364

6465
return res.Release();
6566
}
67+
68+
static const TArrayRef<const float> GetArray(const TStringRef& str) {
69+
//skip format header, it was already read
70+
const char* buf = str.Data() + sizeof(ui32);
71+
const size_t len = str.Size() - sizeof(ui32);
72+
73+
if (len % sizeof(float) != 0)
74+
return {};
75+
76+
const ui32 count = len / sizeof(float);
77+
78+
return MakeArrayRef(reinterpret_cast<const float*>(buf), count);
79+
}
6680
};
6781

6882

@@ -81,13 +95,26 @@ class TSerializerFacade {
8195
if (str.Size() == 0)
8296
return {};
8397

84-
ui8 formatByte = str.Data()[0];
85-
switch (formatByte) {
98+
const ui32* format = reinterpret_cast<const ui32*>(str.Data());
99+
switch (*format) {
86100
case EFormat::FloatVector:
87101
return TFloatVectorSerializer::Deserialize(valueBuilder, str);
88102
default:
89103
return {};
90104
}
91105
}
106+
107+
static const TArrayRef<const float> GetArray(const TStringRef& str) {
108+
if (str.Size() == 0)
109+
return {};
110+
111+
const ui32* format = reinterpret_cast<const ui32*>(str.Data());
112+
switch (*format) {
113+
case EFormat::FloatVector:
114+
return TFloatVectorSerializer::GetArray(str);
115+
default:
116+
return {};
117+
}
118+
}
92119
};
93120

ydb/library/yql/udfs/common/knn/knn.cpp

+25-65
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include <ydb/library/yql/public/udf/udf_helpers.h>
55

6+
#include <library/cpp/dot_product/dot_product.h>
67
#include <util/generic/buffer.h>
78
#include <util/stream/format.h>
89

@@ -23,94 +24,53 @@ SIMPLE_STRICT_UDF(TFromBinaryString, TOptional<TListType<float>>(TAutoMap<const
2324
return TSerializerFacade::Deserialize(valueBuilder, str);
2425
}
2526

27+
SIMPLE_STRICT_UDF(TInnerProductSimilarity, TOptional<float>(TAutoMap<const char*>, TAutoMap<const char*>)) {
28+
Y_UNUSED(valueBuilder);
2629

27-
std::optional<float> InnerProductSimilarity(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
28-
float ret = 0;
29-
30-
if (!EnumerateVectors(vector1, vector2, [&ret](float el1, float el2) { ret += el1 * el2;}))
31-
return {};
32-
33-
return ret;
34-
}
35-
36-
std::optional<float> CosineSimilarity(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
37-
float len1 = 0;
38-
float len2 = 0;
39-
float innerProduct = 0;
40-
41-
if (!EnumerateVectors(vector1, vector2, [&](float el1, float el2) {
42-
innerProduct += el1 * el2;
43-
len1 += el1 * el1;
44-
len2 += el2 * el2;
45-
}))
46-
return {};
47-
48-
len1 = sqrt(len1);
49-
len2 = sqrt(len2);
50-
51-
float cosine = innerProduct / len1 / len2;
52-
53-
return cosine;
54-
}
55-
56-
std::optional<float> EuclideanDistance(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
57-
float ret = 0;
30+
const TArrayRef<const float> vector1 = TSerializerFacade::GetArray(args[0].AsStringRef());
31+
const TArrayRef<const float> vector2 = TSerializerFacade::GetArray(args[1].AsStringRef());
5832

59-
if (!EnumerateVectors(vector1, vector2, [&ret](float el1, float el2) { ret += (el1 - el2) * (el1 - el2);}))
33+
if (vector1.size() != vector2.size() || vector1.empty() || vector2.empty())
6034
return {};
6135

62-
ret = sqrtf(ret);
63-
64-
return ret;
36+
const float dotProduct = DotProduct(vector1.data(), vector2.data(), vector1.size());
37+
return TUnboxedValuePod{dotProduct};
6538
}
6639

67-
SIMPLE_STRICT_UDF(TInnerProductSimilarity, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
40+
SIMPLE_STRICT_UDF(TCosineSimilarity, TOptional<float>(TAutoMap<const char*>, TAutoMap<const char*>)) {
6841
Y_UNUSED(valueBuilder);
6942

70-
auto innerProduct = InnerProductSimilarity(args[0], args[1]);
71-
if (!innerProduct)
72-
return {};
73-
74-
return TUnboxedValuePod{innerProduct.value()};
75-
}
43+
const TArrayRef<const float> vector1 = TSerializerFacade::GetArray(args[0].AsStringRef());
44+
const TArrayRef<const float> vector2 = TSerializerFacade::GetArray(args[1].AsStringRef());
7645

77-
SIMPLE_STRICT_UDF(TCosineSimilarity, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
78-
Y_UNUSED(valueBuilder);
79-
80-
auto cosine = CosineSimilarity(args[0], args[1]);
81-
if (!cosine)
82-
return {};
46+
if (vector1.size() != vector2.size() || vector1.empty() || vector2.empty())
47+
return {};
8348

84-
return TUnboxedValuePod{cosine.value()};
49+
const auto [ll, lr, rr] = TriWayDotProduct(vector1.data(), vector2.data(), vector1.size());
50+
const float cosine = lr / std::sqrt(ll * rr);
51+
return TUnboxedValuePod{cosine};
8552
}
8653

87-
SIMPLE_STRICT_UDF(TCosineDistance, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
54+
SIMPLE_STRICT_UDF(TCosineDistance, TOptional<float>(TAutoMap<const char*>, TAutoMap<const char*>)) {
8855
Y_UNUSED(valueBuilder);
8956

90-
auto cosine = CosineSimilarity(args[0], args[1]);
91-
if (!cosine)
92-
return {};
93-
94-
return TUnboxedValuePod{1 - cosine.value()};
95-
}
57+
const TArrayRef<const float> vector1 = TSerializerFacade::GetArray(args[0].AsStringRef());
58+
const TArrayRef<const float> vector2 = TSerializerFacade::GetArray(args[1].AsStringRef());
9659

97-
SIMPLE_STRICT_UDF(TEuclideanDistance, TOptional<float>(TAutoMap<TListType<float>>, TAutoMap<TListType<float>>)) {
98-
Y_UNUSED(valueBuilder);
99-
100-
auto distance = EuclideanDistance(args[0], args[1]);
101-
if (!distance)
102-
return {};
60+
if (vector1.size() != vector2.size() || vector1.empty() || vector2.empty())
61+
return {};
10362

104-
return TUnboxedValuePod{distance.value()};
63+
const auto [ll, lr, rr] = TriWayDotProduct(vector1.data(), vector2.data(), vector1.size());
64+
const float cosine = lr / std::sqrt(ll * rr);
65+
return TUnboxedValuePod{1 - cosine};
10566
}
10667

10768
SIMPLE_MODULE(TKnnModule,
10869
TFromBinaryString,
10970
TToBinaryString,
11071
TInnerProductSimilarity,
11172
TCosineSimilarity,
112-
TCosineDistance,
113-
TEuclideanDistance
73+
TCosineDistance
11474
)
11575

11676
REGISTER_MODULES(TKnnModule)

ydb/library/yql/udfs/common/knn/test/canondata/result.json

-5
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@
1414
"uri": "file://test.test_DeserializationError_/results.txt"
1515
}
1616
],
17-
"test.test[EuclideanDistance]": [
18-
{
19-
"uri": "file://test.test_EuclideanDistance_/results.txt"
20-
}
21-
],
2217
"test.test[InnerProductSimilarity]": [
2318
{
2419
"uri": "file://test.test_InnerProductSimilarity_/results.txt"

ydb/library/yql/udfs/common/knn/test/canondata/test.test_CosineDistance_/results.txt

+2-126
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"Data" = [
2424
[
2525
[
26-
"0.025368154"
26+
"0.025368094"
2727
]
2828
]
2929
]
@@ -54,131 +54,7 @@
5454
"Data" = [
5555
[
5656
[
57-
"0.025368154"
58-
]
59-
]
60-
]
61-
}
62-
]
63-
};
64-
{
65-
"Write" = [
66-
{
67-
"Type" = [
68-
"ListType";
69-
[
70-
"StructType";
71-
[
72-
[
73-
"column0";
74-
[
75-
"OptionalType";
76-
[
77-
"DataType";
78-
"Float"
79-
]
80-
]
81-
]
82-
]
83-
]
84-
];
85-
"Data" = [
86-
[
87-
[
88-
"0.025368154"
89-
]
90-
]
91-
]
92-
}
93-
]
94-
};
95-
{
96-
"Write" = [
97-
{
98-
"Type" = [
99-
"ListType";
100-
[
101-
"StructType";
102-
[
103-
[
104-
"column0";
105-
[
106-
"OptionalType";
107-
[
108-
"DataType";
109-
"Float"
110-
]
111-
]
112-
]
113-
]
114-
]
115-
];
116-
"Data" = [
117-
[
118-
[
119-
"0.025368154"
120-
]
121-
]
122-
]
123-
}
124-
]
125-
};
126-
{
127-
"Write" = [
128-
{
129-
"Type" = [
130-
"ListType";
131-
[
132-
"StructType";
133-
[
134-
[
135-
"column0";
136-
[
137-
"OptionalType";
138-
[
139-
"DataType";
140-
"Float"
141-
]
142-
]
143-
]
144-
]
145-
]
146-
];
147-
"Data" = [
148-
[
149-
[
150-
"0.025368154"
151-
]
152-
]
153-
]
154-
}
155-
]
156-
};
157-
{
158-
"Write" = [
159-
{
160-
"Type" = [
161-
"ListType";
162-
[
163-
"StructType";
164-
[
165-
[
166-
"column0";
167-
[
168-
"OptionalType";
169-
[
170-
"DataType";
171-
"Float"
172-
]
173-
]
174-
]
175-
]
176-
]
177-
];
178-
"Data" = [
179-
[
180-
[
181-
"0.000000059604645"
57+
"0"
18258
]
18359
]
18460
]

0 commit comments

Comments
 (0)