Skip to content

Commit bd029d2

Browse files
authored
Merge 8a24959 into 2abba45
2 parents 2abba45 + 8a24959 commit bd029d2

File tree

13 files changed

+757
-0
lines changed

13 files changed

+757
-0
lines changed
+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
#include <ydb/library/yql/public/udf/udf_helpers.h>
2+
3+
#include <util/generic/buffer.h>
4+
#include <util/stream/format.h>
5+
6+
using namespace NYql;
7+
using namespace NYql::NUdf;
8+
9+
// One-byte tag written in front of a serialized vector; it tells the reader
// how the bytes that follow are encoded.
enum EFormat : unsigned char { FloatVector = 1 };
12+
13+
// Encodes a list of floats as a binary string: a single EFormat tag byte
// followed by the raw bytes of every element, in list order.
TString SerializeFloatVector(const TUnboxedValuePod x) {
    const EFormat tag = EFormat::FloatVector;

    TString result;
    TStringOutput sink(result);

    const auto directItems = x.GetElements();
    if (!directItems) {
        // No direct element array is available: walk the list via its iterator.
        sink.Write(&tag, sizeof(unsigned char));
        const auto listIt = x.GetListIterator();
        for (TUnboxedValue item; listIt.Next(item);) {
            float value = item.Get<float>();
            sink.Write(&value, sizeof(float));
        }
        return result;
    }

    // Element array is available, so the final size is known up front.
    const auto length = x.GetListLength();
    sink.Reserve(1 + length * sizeof(float));
    sink.Write(&tag, sizeof(unsigned char));
    for (ui32 pos = 0; pos < length; ++pos) {
        float value = directItems[pos].Get<float>();
        sink.Write(&value, sizeof(float));
    }
    return result;
}
37+
38+
// Builds a list of floats from raw payload bytes (the caller passes the bytes
// that follow the format tag). Returns an empty value when the payload size is
// not a multiple of sizeof(float) or when an element cannot be read in full.
NYql::NUdf::TUnboxedValue DeserializeFloatVector(const IValueBuilder *valueBuilder, TStringRef str) {
    const auto payloadSize = str.Size();
    if (payloadSize % sizeof(float) != 0) {
        return {};
    }

    const ui32 count = payloadSize / sizeof(float);

    // NewArray hands back a pointer to the first slot of the new list.
    TUnboxedValue* slot = nullptr;
    auto list = valueBuilder->NewArray(count, slot);

    TMemoryInput reader(str);
    for (ui32 left = count; left > 0; --left, ++slot) {
        float value;
        const auto got = reader.Read(&value, sizeof(float));
        if (got != sizeof(float)) {
            return {};
        }
        *slot = TUnboxedValuePod{value};
    }

    return list.Release();
}
57+
58+
// UDF: converts a list of floats into its tagged binary string form.
SIMPLE_STRICT_UDF(TToBinaryString, char*(TListType<float>)) {
    const TString serialized = SerializeFloatVector(args[0]);
    return valueBuilder->NewString(serialized);
}
61+
62+
// UDF: parses a tagged binary string back into an optional list of floats.
// Yields an empty optional for an empty string, an unknown format tag, or a
// payload that DeserializeFloatVector rejects.
SIMPLE_STRICT_UDF(TFromBinaryString, TOptional<TListType<float>>(const char*)) {
    const TStringRef input = args[0].AsStringRef();
    if (input.Size() == 0) {
        return {};
    }

    // The first byte is the format tag; everything after it is the payload.
    const EFormat format = static_cast<EFormat>(input.Data()[0]);
    const TStringRef payload{input.Data() + 1, input.Size() - 1};

    if (format == EFormat::FloatVector) {
        return DeserializeFloatVector(valueBuilder, payload);
    }
    return {};
}
76+
77+
bool EnumerateVectors(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2, std::function<void(float, float)> callback) {
78+
79+
auto enumerateBothSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValue* elements1, const TUnboxedValuePod vector2, const TUnboxedValue* elements2) {
80+
const auto size1 = vector1.GetListLength();
81+
const auto size2 = vector2.GetListLength();
82+
83+
// Lenght mismatch
84+
if (size1 != size2)
85+
return false;
86+
87+
for (ui32 i = 0; i < size1; ++i) {
88+
callback(elements1[i].Get<float>(), elements2[i].Get<float>());
89+
}
90+
91+
return true;
92+
};
93+
94+
auto enumerateOneSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValue* elements1, const TUnboxedValuePod vector2) {
95+
const auto size = vector1.GetListLength();
96+
ui32 idx = 0;
97+
TUnboxedValue value;
98+
const auto it = vector2.GetListIterator();
99+
100+
while (it.Next(value)) {
101+
callback(elements1[idx++].Get<float>(), value.Get<float>());
102+
}
103+
104+
// Lenght mismatch
105+
if (it.Next(value) || idx != size)
106+
return false;
107+
108+
return true;
109+
};
110+
111+
auto enumerateNoSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
112+
TUnboxedValue value1, value2;
113+
const auto it1 = vector1.GetListIterator();
114+
const auto it2 = vector2.GetListIterator();
115+
for (; it1.Next(value1) && it2.Next(value2);) {
116+
callback(value1.Get<float>(), value2.Get<float>());
117+
}
118+
119+
// Lenght mismatch
120+
if (it1.Next(value1) || it2.Next(value2))
121+
return false;
122+
123+
return true;
124+
};
125+
126+
const auto elements1 = vector1.GetElements();
127+
const auto elements2 = vector2.GetElements();
128+
if (elements1 && elements2) {
129+
if (!enumerateBothSized(vector1, elements1, vector2, elements2))
130+
return false;
131+
} else if (elements1) {
132+
if (!enumerateOneSized(vector1, elements1, vector2))
133+
return false;
134+
} else if (elements2) {
135+
if (!enumerateOneSized(vector2, elements2, vector1))
136+
return false;
137+
} else {
138+
if (!enumerateNoSized(vector1, vector2))
139+
return false;
140+
}
141+
142+
return true;
143+
}
144+
145+
std::optional<float> InnerProductDistance(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
146+
float ret = 0;
147+
148+
if (!EnumerateVectors(vector1, vector2, [&ret](float el1, float el2) { ret += el1 * el2;}))
149+
return {};
150+
151+
return ret;
152+
}
153+
154+
// UDF: inner product of two optional float lists.
// Yields an empty optional when either argument is empty or when
// InnerProductDistance produces no result.
SIMPLE_STRICT_UDF(TInnerProductDistance, TOptional<float>(TOptional<TListType<float>>, TOptional<TListType<float>>)) {
    Y_UNUSED(valueBuilder);

    const bool bothPresent = args[0].HasValue() && args[1].HasValue();
    if (!bothPresent) {
        return {};
    }

    const auto distance = InnerProductDistance(args[0], args[1]);
    if (distance.has_value()) {
        return TUnboxedValuePod{*distance};
    }
    return {};
}
166+
167+
// Groups the three UDFs defined above into TKnnModule and registers the module.
SIMPLE_MODULE(TKnnModule, TFromBinaryString, TToBinaryString, TInnerProductDistance)

REGISTER_MODULES(TKnnModule)
174+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
"test.test[DeserializationError]": [
3+
{
4+
"uri": "file://test.test_DeserializationError_/results.txt"
5+
}
6+
],
7+
"test.test[InnerProductDistance]": [
8+
{
9+
"uri": "file://test.test_InnerProductDistance_/results.txt"
10+
}
11+
],
12+
"test.test[LazyListSerialization]": [
13+
{
14+
"uri": "file://test.test_LazyListSerialization_/results.txt"
15+
}
16+
],
17+
"test.test[ListSerialization]": [
18+
{
19+
"uri": "file://test.test_ListSerialization_/results.txt"
20+
}
21+
]
22+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
[
2+
{
3+
"Write" = [
4+
{
5+
"Type" = [
6+
"ListType";
7+
[
8+
"StructType";
9+
[
10+
[
11+
"column0";
12+
[
13+
"OptionalType";
14+
[
15+
"ListType";
16+
[
17+
"DataType";
18+
"Float"
19+
]
20+
]
21+
]
22+
]
23+
]
24+
]
25+
];
26+
"Data" = [
27+
[
28+
#
29+
]
30+
]
31+
}
32+
]
33+
};
34+
{
35+
"Write" = [
36+
{
37+
"Type" = [
38+
"ListType";
39+
[
40+
"StructType";
41+
[
42+
[
43+
"column0";
44+
[
45+
"OptionalType";
46+
[
47+
"ListType";
48+
[
49+
"DataType";
50+
"Float"
51+
]
52+
]
53+
]
54+
]
55+
]
56+
]
57+
];
58+
"Data" = [
59+
[
60+
#
61+
]
62+
]
63+
}
64+
]
65+
}
66+
]

0 commit comments

Comments
 (0)