1
1
#pragma once
2
2
3
+ #include " knn-defines.h"
3
4
#include " knn-enumerator.h"
4
5
5
6
#include < ydb/library/yql/public/udf/udf_helpers.h>
11
12
using namespace NYql ;
12
13
using namespace NYql ::NUdf;
13
14
14
- enum EFormat : ui8 {
15
- FloatVector = 1
16
- };
17
-
18
- static constexpr size_t HeaderLen = sizeof (ui8);
19
-
20
- class TFloatVectorSerializer {
15
+ template <typename T, EFormat Format>
16
+ class TKnnVectorSerializer {
21
17
public:
22
18
static TUnboxedValue Serialize (const IValueBuilder* valueBuilder, const TUnboxedValue x) {
23
19
auto serialize = [&x] (IOutputStream& outStream) {
24
- EnumerateVector (x, [&outStream] (float element) { outStream.Write (&element, sizeof (float )); });
25
- const EFormat format = EFormat::FloatVector;
20
+ EnumerateVector (x, [&outStream] (float floatElement) {
21
+ T element = static_cast <T>(floatElement);
22
+ outStream.Write (&element, sizeof (T));
23
+ });
24
+ const EFormat format = Format;
26
25
outStream.Write (&format, HeaderLen);
27
26
};
28
27
29
28
if (x.HasFastListLength ()) {
30
- auto str = valueBuilder->NewStringNotFilled (HeaderLen + x.GetListLength () * sizeof (float ));
29
+ auto str = valueBuilder->NewStringNotFilled (HeaderLen + x.GetListLength () * sizeof (T ));
31
30
auto strRef = str.AsStringRef ();
32
31
TMemoryOutput memoryOutput (strRef.Data (), strRef.Size ());
33
32
@@ -46,45 +45,115 @@ class TFloatVectorSerializer {
46
45
const char * buf = str.Data ();
47
46
const size_t len = str.Size () - HeaderLen;
48
47
49
- if (len % sizeof (float ) != 0 )
48
+ if (Y_UNLIKELY ( len % sizeof (T ) != 0 ))
50
49
return {};
51
50
52
- const ui32 count = len / sizeof (float );
51
+ const ui32 count = len / sizeof (T );
53
52
54
53
TUnboxedValue* items = nullptr ;
55
54
auto res = valueBuilder->NewArray (count, items);
56
55
57
56
TMemoryInput inStr (buf, len);
58
57
for (ui32 i = 0 ; i < count; ++i) {
59
- float element;
60
- if (inStr.Read (&element, sizeof (float )) != sizeof (float ))
58
+ T element;
59
+ if (Y_UNLIKELY ( inStr.Read (&element, sizeof (T )) != sizeof (T) ))
61
60
return {};
62
- *items++ = TUnboxedValuePod{element};
61
+ *items++ = TUnboxedValuePod{static_cast < float >( element) };
63
62
}
64
63
65
64
return res.Release ();
66
65
}
67
66
68
- static const TArrayRef<const float > GetArray (const TStringRef& str) {
67
+ static const TArrayRef<const T > GetArray (const TStringRef& str) {
69
68
const char * buf = str.Data ();
70
69
const size_t len = str.Size () - HeaderLen;
71
70
72
- if (len % sizeof (float ) != 0 )
71
+ if (Y_UNLIKELY ( len % sizeof (T ) != 0 ))
73
72
return {};
74
73
75
- const ui32 count = len / sizeof (float );
74
+ const ui32 count = len / sizeof (T );
76
75
77
- return MakeArrayRef (reinterpret_cast <const float *>(buf), count);
76
+ return MakeArrayRef (reinterpret_cast <const T *>(buf), count);
78
77
}
79
78
};
80
79
80
+ // Encode all positive floats as bit 1, negative floats as bit 0.
81
+ // So 1024 float vector is serialized in 1024/8=128 bytes.
82
+ // Place all bits in ui64. So, only vector sizes divisible by 64 are supported.
83
+ // Max vector lenght is 32767.
84
+ class TKnnBitVectorSerializer {
85
+ public:
86
+ static TUnboxedValue Serialize (const IValueBuilder* valueBuilder, const TUnboxedValue x) {
87
+ auto serialize = [&x] (IOutputStream& outStream) {
88
+ ui64 accumulator = 0 ;
89
+ ui8 filledBits = 0 ;
90
+ ui64 lenght = 0 ;
91
+
92
+ EnumerateVector (x, [&] (float element) {
93
+ if (element > 0 )
94
+ accumulator |= 1ll << filledBits;
95
+
96
+ ++filledBits;
97
+ if (filledBits == 64 ) {
98
+ outStream.Write (&accumulator, sizeof (ui64));
99
+ lenght++;
100
+ accumulator = 0 ;
101
+ filledBits = 0 ;
102
+ }
103
+ });
104
+
105
+ // only vector sizes divisible by 64 are supported
106
+ if (Y_UNLIKELY (filledBits))
107
+ return false ;
108
+
109
+ // max vector lenght is 32767
110
+ if (Y_UNLIKELY (lenght > UINT16_MAX))
111
+ return false ;
112
+
113
+ const EFormat format = EFormat::BitVector;
114
+ outStream.Write (&format, HeaderLen);
115
+
116
+ return true ;
117
+ };
118
+
119
+ if (x.HasFastListLength ()) {
120
+ auto str = valueBuilder->NewStringNotFilled (HeaderLen + x.GetListLength () / 8 );
121
+ auto strRef = str.AsStringRef ();
122
+ TMemoryOutput memoryOutput (strRef.Data (), strRef.Size ());
81
123
82
- class TSerializerFacade {
124
+ if (Y_UNLIKELY (!serialize (memoryOutput)))
125
+ return {};
126
+
127
+ return str;
128
+ } else {
129
+ TString str;
130
+ TStringOutput stringOutput (str);
131
+
132
+ if (Y_UNLIKELY (!serialize (stringOutput)))
133
+ return {};
134
+
135
+ return valueBuilder->NewString (str);
136
+ }
137
+ }
138
+
139
+ static const TArrayRef<const ui64> GetArray64 (const TStringRef& str) {
140
+ const char * buf = str.Data ();
141
+ const size_t len = (str.Size () - HeaderLen) / sizeof (ui64);
142
+
143
+ return MakeArrayRef (reinterpret_cast <const ui64*>(buf), len);
144
+ }
145
+ };
146
+
147
+ class TKnnSerializerFacade {
83
148
public:
84
149
static TUnboxedValue Serialize (EFormat format, const IValueBuilder* valueBuilder, const TUnboxedValue x) {
85
150
switch (format) {
86
151
case EFormat::FloatVector:
87
- return TFloatVectorSerializer::Serialize (valueBuilder, x);
152
+ return TKnnVectorSerializer<float , EFormat::FloatVector>::Serialize (valueBuilder, x);
153
+ case EFormat::ByteVector:
154
+ return TKnnVectorSerializer<ui8, EFormat::ByteVector>::Serialize (valueBuilder, x);
155
+ case EFormat::BitVector:
156
+ return TKnnBitVectorSerializer::Serialize (valueBuilder, x);
88
157
default :
89
158
return {};
90
159
}
@@ -97,20 +166,29 @@ class TSerializerFacade {
97
166
const ui8 format = str.Data ()[str.Size () - HeaderLen];
98
167
switch (format) {
99
168
case EFormat::FloatVector:
100
- return TFloatVectorSerializer::Deserialize (valueBuilder, str);
169
+ return TKnnVectorSerializer<float , EFormat::FloatVector>::Deserialize (valueBuilder, str);
170
+ case EFormat::ByteVector:
171
+ return TKnnVectorSerializer<ui8, EFormat::ByteVector>::Deserialize (valueBuilder, str);
172
+ case EFormat::BitVector:
173
+ return {};
101
174
default :
102
175
return {};
103
176
}
104
177
}
105
178
106
- static const TArrayRef<const float > GetArray (const TStringRef& str) {
107
- if (str.Size () == 0 )
179
+ template <typename T>
180
+ static const TArrayRef<const T> GetArray (const TStringRef& str) {
181
+ if (Y_UNLIKELY (str.Size () == 0 ))
108
182
return {};
109
183
110
184
const ui8 format = str.Data ()[str.Size () - HeaderLen];
111
185
switch (format) {
112
186
case EFormat::FloatVector:
113
- return TFloatVectorSerializer::GetArray (str);
187
+ return TKnnVectorSerializer<T, EFormat::FloatVector>::GetArray (str);
188
+ case EFormat::ByteVector:
189
+ return TKnnVectorSerializer<T, EFormat::ByteVector>::GetArray (str);
190
+ case EFormat::BitVector:
191
+ return {};
114
192
default :
115
193
return {};
116
194
}
0 commit comments