Skip to content

Commit 3e5ea58

Browse files
authored
Merge 897ee20 into bcbb8e1
2 parents bcbb8e1 + 897ee20 commit 3e5ea58

File tree

13 files changed

+755
-0
lines changed

13 files changed

+755
-0
lines changed
+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#include <ydb/library/yql/public/udf/udf_helpers.h>
2+
3+
#include <util/generic/buffer.h>
4+
#include <util/stream/format.h>
5+
6+
using namespace NYql;
7+
using namespace NYql::NUdf;
8+
9+
enum EFormat : unsigned char {
10+
DoubleVector = 1
11+
};
12+
13+
TString SerializeDoubleVector(const TUnboxedValuePod x) {
14+
const EFormat format = EFormat::DoubleVector;
15+
16+
TString str;
17+
TStringOutput outStr(str);
18+
if (const auto elements = x.GetElements()) {
19+
const auto size = x.GetListLength();
20+
outStr.Reserve(1 + size * sizeof(double));
21+
outStr.Write(&format, sizeof(unsigned char));
22+
for (ui32 i = 0; i < size; ++i) {
23+
double element = elements[i].Get<double>();
24+
outStr.Write(&element, sizeof(double));
25+
}
26+
} else {
27+
outStr.Write(&format, sizeof(unsigned char));
28+
const auto it = x.GetListIterator();
29+
TUnboxedValue v;
30+
while(it.Next(v)) {
31+
double element = v.Get<double>();
32+
outStr.Write(&element, sizeof(double));
33+
}
34+
}
35+
return str;
36+
}
37+
38+
NYql::NUdf::TUnboxedValue DeserializeDoubleVector(const IValueBuilder *valueBuilder, TStringRef str) {
39+
if (str.Size() % sizeof(double) != 0)
40+
return {};
41+
42+
const ui32 count = str.Size() / sizeof(double);
43+
44+
TUnboxedValue* items = nullptr;
45+
auto res = valueBuilder->NewArray(count, items);
46+
47+
for (ui32 i = 0; i < count; ++i) {
48+
double element;
49+
memcpy(&element, str.Data() + i * sizeof(double), sizeof(double));
50+
*items++ = TUnboxedValuePod{element};
51+
}
52+
53+
return res.Release();
54+
}
55+
56+
SIMPLE_STRICT_UDF(TToBinaryString, char*(TListType<double>)) {
57+
return valueBuilder->NewString(SerializeDoubleVector(args[0]));
58+
}
59+
60+
SIMPLE_STRICT_UDF(TFromBinaryString, TOptional<TListType<double>>(const char*)) {
61+
TStringRef str = args[0].AsStringRef();
62+
if (str.Size() == 0)
63+
return {};
64+
65+
const EFormat format = static_cast<EFormat>(str.Data()[0]);
66+
str = TStringRef{str.Data() + 1, str.Size() -1};
67+
switch (format) {
68+
case EFormat::DoubleVector:
69+
return DeserializeDoubleVector(valueBuilder, str);
70+
default:
71+
return {};
72+
}
73+
}
74+
75+
bool EnumerateVectors(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2, std::function<void(double, double)> callback) {
76+
77+
auto enumerateBothSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValue* elements1, const TUnboxedValuePod vector2, const TUnboxedValue* elements2) {
78+
const auto size1 = vector1.GetListLength();
79+
const auto size2 = vector2.GetListLength();
80+
81+
// Lenght mismatch
82+
if (size1 != size2)
83+
return false;
84+
85+
for (ui32 i = 0; i < size1; ++i) {
86+
callback(elements1[i].Get<double>(), elements2[i].Get<double>());
87+
}
88+
89+
return true;
90+
};
91+
92+
auto enumerateOneSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValue* elements1, const TUnboxedValuePod vector2) {
93+
const auto size = vector1.GetListLength();
94+
ui32 idx = 0;
95+
TUnboxedValue value;
96+
const auto it = vector2.GetListIterator();
97+
98+
while (it.Next(value)) {
99+
callback(elements1[idx++].Get<double>(), value.Get<double>());
100+
}
101+
102+
// Lenght mismatch
103+
if (it.Next(value) || idx != size)
104+
return false;
105+
106+
return true;
107+
};
108+
109+
auto enumerateNoSized = [&callback] (const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
110+
TUnboxedValue value1, value2;
111+
const auto it1 = vector1.GetListIterator();
112+
const auto it2 = vector2.GetListIterator();
113+
for (; it1.Next(value1) && it2.Next(value2);) {
114+
callback(value1.Get<double>(), value2.Get<double>());
115+
}
116+
117+
// Lenght mismatch
118+
if (it1.Next(value1) || it2.Next(value2))
119+
return false;
120+
121+
return true;
122+
};
123+
124+
const auto elements1 = vector1.GetElements();
125+
const auto elements2 = vector2.GetElements();
126+
if (elements1 && elements2) {
127+
if (!enumerateBothSized(vector1, elements1, vector2, elements2))
128+
return false;
129+
} else if (elements1) {
130+
if (!enumerateOneSized(vector1, elements1, vector2))
131+
return false;
132+
} else if (elements2) {
133+
if (!enumerateOneSized(vector2, elements2, vector1))
134+
return false;
135+
} else {
136+
if (!enumerateNoSized(vector1, vector2))
137+
return false;
138+
}
139+
140+
return true;
141+
}
142+
143+
std::optional<double> InnerProductDistance(const TUnboxedValuePod vector1, const TUnboxedValuePod vector2) {
144+
double ret = 0;
145+
146+
if (!EnumerateVectors(vector1, vector2, [&ret](double el1, double el2) { ret += el1 * el2;}))
147+
return {};
148+
149+
return ret;
150+
}
151+
152+
SIMPLE_STRICT_UDF(TInnerProductDistance, TOptional<double>(TOptional<TListType<double>>, TOptional<TListType<double>>)) {
153+
Y_UNUSED(valueBuilder);
154+
155+
if (!args[0].HasValue() || !args[1].HasValue())
156+
return {};
157+
158+
auto distance = InnerProductDistance(args[0], args[1]);
159+
if (!distance)
160+
return {};
161+
162+
return TUnboxedValuePod{distance.value()};
163+
}
164+
165+
SIMPLE_MODULE(TKnnModule,
166+
TFromBinaryString,
167+
TToBinaryString,
168+
TInnerProductDistance
169+
)
170+
171+
REGISTER_MODULES(TKnnModule)
172+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
"test.test[DeserializationError]": [
3+
{
4+
"uri": "file://test.test_DeserializationError_/results.txt"
5+
}
6+
],
7+
"test.test[InnerProductDistance]": [
8+
{
9+
"uri": "file://test.test_InnerProductDistance_/results.txt"
10+
}
11+
],
12+
"test.test[LazyListSerialization]": [
13+
{
14+
"uri": "file://test.test_LazyListSerialization_/results.txt"
15+
}
16+
],
17+
"test.test[ListSerialization]": [
18+
{
19+
"uri": "file://test.test_ListSerialization_/results.txt"
20+
}
21+
]
22+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
[
2+
{
3+
"Write" = [
4+
{
5+
"Type" = [
6+
"ListType";
7+
[
8+
"StructType";
9+
[
10+
[
11+
"column0";
12+
[
13+
"OptionalType";
14+
[
15+
"ListType";
16+
[
17+
"DataType";
18+
"Double"
19+
]
20+
]
21+
]
22+
]
23+
]
24+
]
25+
];
26+
"Data" = [
27+
[
28+
#
29+
]
30+
]
31+
}
32+
]
33+
};
34+
{
35+
"Write" = [
36+
{
37+
"Type" = [
38+
"ListType";
39+
[
40+
"StructType";
41+
[
42+
[
43+
"column0";
44+
[
45+
"OptionalType";
46+
[
47+
"ListType";
48+
[
49+
"DataType";
50+
"Double"
51+
]
52+
]
53+
]
54+
]
55+
]
56+
]
57+
];
58+
"Data" = [
59+
[
60+
#
61+
]
62+
]
63+
}
64+
]
65+
}
66+
]

0 commit comments

Comments
 (0)